diff --git a/README.md b/README.md
index 44644c0def69245cf1e7d20e921c928e36e0f6d9..8264706fc779a73624707592d59f3b8257ea32be 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,122 @@
----
-title: Sat3density
-emoji: 🏆
-colorFrom: green
-colorTo: blue
-sdk: gradio
-sdk_version: 3.41.2
-app_file: app.py
-pinned: false
-license: other
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs
+
+> [Ming Qian](https://qianmingduowan.github.io/), Jincheng Xiong, [Gui-Song Xia](http://www.captain-whu.com/xia_En.html), [Nan Xue](https://xuenan.net)
+>
+> IEEE/CVF International Conference on Computer Vision (ICCV), 2023
+>
+> [Project](https://sat2density.github.io/) | [Paper](https://arxiv.org/abs/2303.14672) | [Data]() | [Install.md](docs/INSTALL.md)
+
+> <p align="center" float="left">
+>    <img src="docs/figures/demo/case1.sat.gif" alt="drawing" width="19%">  
+>    <img src="docs/figures/demo-density/case1.gif" alt="drawing" width="38%">
+>    <img src="docs/figures/demo/case1.render.gif" alt="drawing" width="38%">
+> </p>
+
+> <p align="center" float="left">
+>    <img src="docs/figures/demo/case2.sat.gif" alt="drawing" width="19%">  
+>    <img src="docs/figures/demo-density/case2.gif" alt="drawing" width="38%">
+>    <img src="docs/figures/demo/case2.render.gif" alt="drawing" width="38%">
+> </p>
+
+> <p align="center" float="left">
+>    <img src="docs/figures/demo/case3.sat.gif" alt="drawing" width="19%">  
+>    <img src="docs/figures/demo-density/case3.gif" alt="drawing" width="38%">
+>    <img src="docs/figures/demo/case3.render.gif" alt="drawing" width="38%">
+> </p>
+
+> <p align="center" float="left">
+>    <img src="docs/figures/demo/case4.sat.gif" alt="drawing" width="19%">  
+>    <img src="docs/figures/demo-density/case4.gif" alt="drawing" width="38%">
+>    <img src="docs/figures/demo/case4.render.gif" alt="drawing" width="38%">
+> </p>
+
+## Downloading Checkpoints
+> The pretrained checkpoints for CVACT and CVUSA can be found at [this release page](https://github.com/sat2density/checkpoints/releases). You can also run the following command to download them.
+```bash
+bash scripts/download_weights.sh
+```
+
+## QuickStart Demo
+### Video Synthesis
+#### Example Usage
+```bash
+python test.py --yaml=sat2density_cvact \
+  --test_ckpt_path=2u87bj8w \
+  --task=test_vid \
+  --demo_img=demo_img/case1/satview-input.png \
+  --sty_img=demo_img/case1/groundview.image.png \
+  --save_dir=results/case1
+```
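+
+If you would like to assemble the rendered frames into a GIF, the two-pass ffmpeg palette commands in `demo_img/runall.sh` can be adapted. A minimal sketch, assuming the frames from the command above were written to `results/case1/rendered_images+depths/`:
+```bash
+# Generate a color palette from the rendered frames, then use it to encode the GIF.
+ffmpeg -framerate 10 -i results/case1/rendered_images+depths/%5d.png -vf "palettegen" results/case1-palette.png
+ffmpeg -framerate 10 -i results/case1/rendered_images+depths/%5d.png -i results/case1-palette.png -filter_complex "paletteuse" results/case1/render.gif
+```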
+
+### Illumination Interpolation
+```bash
+python test.py --task=test_interpolation \
+--yaml=sat2density_cvact \
+--test_ckpt_path=2u87bj8w \
+--sty_img1=demo_img/case9/groundview.image.png \
+--sty_img2=demo_img/case7/groundview.image.png \
+--demo_img=demo_img/case3/satview-input.png \
+--save_dir=results/case2
+```
+
+## Train & Inference
+- *We trained our model on a single V100 32 GB GPU; training takes about 20 hours.*
+- *For data preparation, please check out [data.md](dataset/INSTALL.md).*
+
+### Inference
+
+To test the Center Ground-View Synthesis setting, run one of the commands below.
+To save the synthesized results, add `--task=vis_test`.
+```bash
+# CVACT
+python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w
+# CVUSA
+python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4
+```
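+
+For example, a minimal variant of the CVACT command above that also writes the synthesized results to disk:
+```bash
+python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --task=vis_test
+```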
+
+To run inference with different illumination styles:
+```bash
+# CVACT
+bash inference/single_style_test_cvact.sh
+# CVUSA
+bash inference/single_style_test_cvusa.sh
+```
+
+To synthesize ground-view videos:
+```bash
+bash inference/synthesis_video.sh
+```
+
+## Training
+
+### Training command
+
+Replace `X` below with the index of the GPU to use.
+
+```bash
+# CVACT
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvact
+# CVUSA
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvusa
+```
+
+## Citation
+If you use this code for your research, please cite:
+
+```bibtex
+@inproceedings{qian2021sat2density,
+  title={Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs},
+  author={Qian, Ming and Xiong, Jincheng and Xia, Gui-Song and Xue, Nan},
+  booktitle={ICCV},
+  year={2023}
+}
+```
+
+## License
+This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
+For commercial use, please contact [mingqian@whu.edu.cn](mailto:mingqian@whu.edu.cn).
diff --git a/__pycache__/options.cpython-38.pyc b/__pycache__/options.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae86524e6a0305cbf3f915a831f72c3d7e62ade3
Binary files /dev/null and b/__pycache__/options.cpython-38.pyc differ
diff --git a/__pycache__/test.cpython-38.pyc b/__pycache__/test.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c26acf14978dd41cdcadaf20b1fec863948c7f3
Binary files /dev/null and b/__pycache__/test.cpython-38.pyc differ
diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c490c388bbf3dd8d18683faeebf085e3623d0b76
Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56f1de13f4647636798b3e51522d193342b9ea4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,244 @@
+import gradio as gr
+import numpy as np
+import os
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+import options
+import test
+import importlib
+from scipy.interpolate import interp1d, splev, splprep
+import cv2
+
+
+def get_single(sat_img, style_img, x_offset, y_offset):
+    name = ''
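+    # Determine which demo case the selected style image belongs to, so that its sky mask can be loaded.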
+    for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+        style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+        style =np.array(style)
+        if (style == style_img).all():
+            name = i
+            break
+
+    input_dict = {}
+    trans = transforms.ToTensor()
+    input_dict['sat'] = trans(sat_img)
+    input_dict['pano'] = trans(style_img)
+    input_dict['paths'] = "demo.png"
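+    # Load the sky mask and build a per-channel histogram of the sky pixels; this histogram serves as the style/illumination descriptor.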
+    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+    input_a = input_dict['pano']*sky
+    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+    input_dict['sky_histc'] = sky_histc
+    input_dict['sky_mask'] = sky
+
+    for key in input_dict.keys():
+        if isinstance(input_dict[key], torch.Tensor):
+            input_dict[key] = input_dict[key].unsqueeze(0)
+
+    args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
+            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
+    opt_cmd = options.parse_arguments(args=args)
+    opt = options.set(opt_cmd=opt_cmd)
+    opt.isTrain = False
+    opt.name = opt.yaml if opt.name is None else opt.name
+    opt.batch_size = 1
+
+    m = importlib.import_module("model.{}".format(opt.model))
+    model = m.Model(opt)
+
+    # m.load_dataset(opt)
+    model.build_networks(opt)
+    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+    model.netG.load_state_dict(ckpt['netG'])
+    model.netG.eval()
+
+    model.set_input(input_dict)
+    
+    model.style_temp = model.sky_histc
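+    # Convert the slider positions in [0, 1] to rendering-origin coordinates in [-1, 1] (the y-axis is flipped).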
+    opt.origin_H_W = [-(y_offset*256-128)/128, (x_offset*256-128)/128] # TODO: hard code should be removed in the future
+
+    model.forward(opt)
+
+    rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().detach().numpy().transpose((1,2,0))
+    rgb = np.array(rgb*255, dtype=np.uint8)
+    return  rgb
+
+def get_video(sat_img, style_img, positions):
+    name = ''
+    for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+        style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+        style =np.array(style)
+        if (style == style_img).all():
+            name = i
+            break
+
+    input_dict = {}
+    trans = transforms.ToTensor()
+    input_dict['sat'] = trans(sat_img)
+    input_dict['pano'] = trans(style_img)
+    input_dict['paths'] = "demo.png"
+    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+    input_a = input_dict['pano']*sky
+    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+    input_dict['sky_histc'] = sky_histc
+    input_dict['sky_mask'] = sky
+
+    for key in input_dict.keys():
+        if isinstance(input_dict[key], torch.Tensor):
+            input_dict[key] = input_dict[key].unsqueeze(0)
+
+    args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
+            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
+    opt_cmd = options.parse_arguments(args=args)
+    opt = options.set(opt_cmd=opt_cmd)
+    opt.isTrain = False
+    opt.name = opt.yaml if opt.name is None else opt.name
+    opt.batch_size = 1
+
+    m = importlib.import_module("model.{}".format(opt.model))
+    model = m.Model(opt)
+
+    # m.load_dataset(opt)
+    model.build_networks(opt)
+    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+    model.netG.load_state_dict(ckpt['netG'])
+    model.netG.eval()
+
+    model.set_input(input_dict)
+    
+    model.style_temp = model.sky_histc
+
+    # Drop duplicate clicks while preserving the order in which they were selected.
+    pixels = np.array(list(dict.fromkeys(positions)))
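+    # Fit a smoothing spline through the selected points and resample 80 evenly spaced positions along the trajectory.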
+    tck, u = splprep(pixels.T, s=25, per=0)
+    u_new = np.linspace(u.min(), u.max(), 80)
+    x_new, y_new = splev(u_new, tck)
+    smooth_path = np.array([x_new,y_new]).T
+
+    rendered_image_list = []
+    rendered_depth_list = []
+
+
+    for i, (x,y) in enumerate(smooth_path):
+        opt.origin_H_W = [(y-128)/128, (x-128)/128] # TODO: hard code should be removed in the future
+        print('Rendering at ({}, {})'.format(x,y))
+        model.forward(opt)
+
+        rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().detach().numpy().transpose((1,2,0))
+        rgb = np.array(rgb*255, dtype=np.uint8)
+        rendered_image_list.append(rgb)
+
+        rendered_depth_list.append(
+            model.out_put.depth[0,0].cpu().detach().numpy()
+        )
+
+    output_video_path = 'output_video.mp4'
+
+    # Set the output video frame rate, width, and height.
+    frame_rate = 15
+    frame_width = 512
+    frame_height = 128
+
+    # Create the video writer with OpenCV (mp4v fourcc).
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
+
+    # Write each rendered frame to the video.
+    for image_np in rendered_image_list:
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        out.write(image_np)
+
+    # Release the video writer.
+    out.release()
+
+    return "output_video.mp4"
+
+def copy_image(image):
+    return image
+
+def show_image_and_point(image, x, y):
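+    # Draw a circular marker mask at the slider-selected position, returned as an overlay for gr.AnnotatedImage.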
+    x = int(x*image.shape[1])
+    y = image.shape[0]-int(y*image.shape[0])
+    mask = np.zeros(image.shape[:2])
+    radius = min(image.shape[0], image.shape[1])//60
+    for i in range(x-radius-2, x+radius+2):
+        for j in range(y-radius-2, y+radius+2):
+            if (i-x)**2+(j-y)**2<=radius**2:
+                mask[j, i] = 1
+    return (image, [(mask, 'render point')])
+
+def add_select_point(image, evt: gr.SelectData, state1):
+    if state1 is None:
+        state1 = []
+    x, y = evt.index
+    state1.append((x, y))
+    print(state1)
+    radius = min(image.shape[0], image.shape[1])//60
+    for i in range(x-radius-2, x+radius+2):
+        for j in range(y-radius-2, y+radius+2):
+            if (i-x)**2+(j-y)**2<=radius**2:
+                image[j, i, :] = 0
+    return image, state1
+
+def reset_select_points(image):
+    return image, []
+
+
+
+
+
+
+with gr.Blocks() as demo:    
+    gr.Markdown("# Sat2Density Demos")
+    gr.Markdown("### select/upload the satllite image and select the style image")
+    with gr.Row():
+        with gr.Column():
+            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
+            img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+                inputs=sat_img, outputs=None, examples_per_page=20)
+        with gr.Column():
+            style_img = gr.Image()
+            style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+                inputs=style_img, outputs=None, examples_per_page=20)
+
+
+    gr.Markdown("### select a certain point to generate single groundview image")
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
+                    slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
+                    btn_single = gr.Button(value="demo1")
+                
+                annotation_image = gr.AnnotatedImage()
+    
+        out_single = gr.Image()
+        
+    gr.Markdown("### draw a trajectory on the map to generate video")
+    state_select_points = gr.State()
+    with gr.Row():
+        with gr.Column():
+            draw_img = gr.Image(shape=[256, 256], interactive=True)  
+        with gr.Column():
+            out_video = gr.Video()
+            reset_btn =gr.Button(value="Reset")
+            btn_video = gr.Button(value="demo1")
+
+    sat_img.change(copy_image, inputs = sat_img, outputs=draw_img)
+
+    draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
+    sat_img.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image)
+    slider_x.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden')
+    slider_y.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden')
+    btn_single.click(get_single, inputs = [sat_img, style_img, slider_x, slider_y], outputs=out_single)
+    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
+    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video) # trigger video rendering
+
+
+demo.launch()
\ No newline at end of file
diff --git a/data/CVACT_Shi.py b/data/CVACT_Shi.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4588ccbe61427e4bec55e7d32bb59428c1228f
--- /dev/null
+++ b/data/CVACT_Shi.py
@@ -0,0 +1,119 @@
+import torch,os
+from torch.utils.data.dataset import Dataset
+from PIL import Image
+import scipy.io as sio
+import torchvision.transforms as transforms
+
+def data_list(img_root,mode):
+    exist_aer_list = os.listdir(os.path.join(img_root , 'satview_correct'))
+    exist_grd_list = os.listdir(os.path.join(img_root , 'streetview'))
+    allDataList = os.path.join(img_root, 'ACT_data.mat')
+    anuData = sio.loadmat(allDataList)
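+    # ACT_data.mat stores the panorama ids together with the official train/val split indices.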
+
+    all_data_list = []
+    for i in range(0, len(anuData['panoIds'])):
+        grd_id_align = anuData['panoIds'][i] + '_grdView.png'
+        sat_id_ori = anuData['panoIds'][i] + '_satView_polish.png'
+        all_data_list.append([grd_id_align, sat_id_ori])
+
+    data_list = []
+    
+    if mode=='train':
+        training_inds = anuData['trainSet']['trainInd'][0][0] - 1
+        trainNum = len(training_inds)
+        for k in range(trainNum):
+            data_list.append(all_data_list[training_inds[k][0]])
+    else:
+        val_inds = anuData['valSet']['valInd'][0][0] - 1
+        valNum = len(val_inds)
+        for k in range(valNum):
+            data_list.append(all_data_list[val_inds[k][0]])
+
+
+    pano_list = [img_root + 'streetview/' + item[0] for item in data_list if item[0] in exist_grd_list and item[1] in exist_aer_list]
+
+    return pano_list
+    
+def img_read(img,size=None,datatype='RGB'):
+    img = Image.open(img).convert('RGB' if datatype=='RGB' else "L")
+    if size:
+        if type(size) is int:
+            size = (size,size)
+        img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST)
+    img = transforms.ToTensor()(img)
+    return img
+
+
+class Dataset(Dataset):
+    def __init__(self, opt,split='train',sub=None,sty_img=None):
+        if sty_img:
+            assert sty_img.endswith('grdView.png')
+            demo_img_path = os.path.join(opt.data.root,'streetview',sty_img)
+            self.pano_list = [demo_img_path]
+
+        elif opt.task in  ['test_vid','test_interpolation'] :
+            demo_img_path = os.path.join(opt.data.root,'streetview',opt.demo_img.replace('satView_polish.png','grdView.png'))
+            self.pano_list = [demo_img_path]
+
+        else:
+            self.pano_list = data_list(img_root=opt.data.root,mode=split)
+            if sub:
+                self.pano_list = self.pano_list[:sub]
+        
+        # Select some ground images to test the influence of different skies.
+        # Different skies induce different illumination intensities, colors, etc.
+        if opt.task == 'test_sty':
+            demo_name = [
+                'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png',
+                'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png',
+                'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png',
+                'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png',
+                'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png',
+                'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png',
+                'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png',
+                'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png',
+                'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png',
+                'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png',
+                'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png',
+                'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png',
+                'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png',
+                'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png',
+                'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png',
+            ]
+            self.pano_list = demo_name
+
+        self.opt = opt
+
+    def __len__(self):
+        return len(self.pano_list)
+
+    def __getitem__(self, index):
+        pano = self.pano_list[index]
+        aer = pano.replace('streetview','satview_correct').replace('_grdView','_satView_polish')
+        if self.opt.data.sky_mask:
+            sky = pano.replace('streetview','pano_sky_mask')
+        name = pano
+        aer = img_read(aer,  size = self.opt.data.sat_size)
+        pano = img_read(pano,size = self.opt.data.pano_size)
+        if self.opt.data.sky_mask:
+            sky = img_read(sky,size=self.opt.data.pano_size,datatype='L')
+
+        input = {}
+        input['sat']=aer
+        input['pano']=pano
+        input['paths']=name
+        if self.opt.data.sky_mask:
+            input['sky_mask']=sky
+            black_ground = torch.zeros_like(pano)
+            if self.opt.data.histo_mode =='grey':
+                input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:] 
+            elif self.opt.data.histo_mode in ['rgb','RGB']:
+                input_a  = (pano*sky+black_ground*(1-sky))
+                for idx in range(len(input_a)):
+                    if idx == 0:
+                        sky_histc = input_a[idx].histc()[10:]
+                    else:
+                        sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0)
+                input['sky_histc'] = sky_histc
+        return input
+
diff --git a/data/CVUSA.py b/data/CVUSA.py
new file mode 100644
index 0000000000000000000000000000000000000000..2179427bf165335df33ab0761ea28975f7672a45
--- /dev/null
+++ b/data/CVUSA.py
@@ -0,0 +1,86 @@
+import torch,os
+from torch.utils.data.dataset import Dataset
+from PIL import Image
+import torchvision.transforms as transforms
+import re
+from easydict import EasyDict as edict
+
+def data_list(img_root,mode):
+    data_list=[]
+    if mode=='train':
+        split_file=os.path.join(img_root, 'splits/train-19zl.csv')
+        with open(split_file) as f:
+            lines = f.readlines()
+            for i in lines:
+                aerial_name=re.split(r',', re.split('\n', i)[0])[0]
+                panorama_name = re.split(r',', re.split('\n', i)[0])[1]
+                data_list.append([aerial_name, panorama_name])
+    else:
+        split_file=os.path.join(img_root, 'splits/val-19zl.csv')
+        with open(split_file) as f:
+            lines = f.readlines()
+            for i in lines:
+                aerial_name=re.split(r',', re.split('\n', i)[0])[0]
+                panorama_name = re.split(r',', re.split('\n', i)[0])[1]
+                data_list.append([aerial_name, panorama_name])
+    print('length of dataset is: ', len(data_list))
+    return [os.path.join(img_root, i[1]) for i in data_list]
+    
+def img_read(img,size=None,datatype='RGB'):
+    img = Image.open(img).convert('RGB' if datatype=='RGB' else "L")
+    if size:
+        if type(size) is int:
+            size = (size,size)
+        img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST)
+    img = transforms.ToTensor()(img)
+    return img
+
+
+class Dataset(Dataset):
+    def __init__(self, opt,split='train',sub=None,sty_img=None):
+        self.pano_list = data_list(img_root=opt.data.root,mode=split)
+        if sub:
+            self.pano_list = self.pano_list[:sub]
+        if opt.task == 'test_vid':
+            demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.demo_img)
+            self.pano_list = [demo_img_path]
+        if sty_img:
+            assert opt.sty_img.split('.')[-1] == 'jpg'
+            demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.sty_img)
+            self.pano_list = [demo_img_path]
+
+        self.opt = opt
+
+    def __len__(self):
+        return len(self.pano_list)
+
+    def __getitem__(self, index):
+        pano = self.pano_list[index]
+        aer = pano.replace('streetview/panos', 'bingmap/19')
+        if self.opt.data.sky_mask:
+            sky = pano.replace('streetview/panos','sky_mask').replace('jpg', 'png')
+        name = pano
+        aer = img_read(aer,  size = self.opt.data.sat_size)
+        pano = img_read(pano,size = self.opt.data.pano_size)
+        if self.opt.data.sky_mask:
+            sky = img_read(sky,size=self.opt.data.pano_size,datatype='L')
+
+        input = {}
+        input['sat']=aer
+        input['pano']=pano
+        input['paths']=name
+        if self.opt.data.sky_mask:
+            input['sky_mask']=sky
+            black_ground = torch.zeros_like(pano)
+            if self.opt.data.histo_mode =='grey':
+                input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:] 
+            elif self.opt.data.histo_mode in ['rgb','RGB']:
+                input_a  = (pano*sky+black_ground*(1-sky))
+                for idx in range(len(input_a)):
+                    if idx == 0:
+                        sky_histc = input_a[idx].histc()[10:]
+                    else:
+                        sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0)
+                input['sky_histc'] = sky_histc
+        return input
+
diff --git a/dataset/INSTALL.md b/dataset/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d0784b008b078f757c910b16b08edc08a7ed6f8
--- /dev/null
+++ b/dataset/INSTALL.md
@@ -0,0 +1,32 @@
+To reproduce our paper, first download the following 4 zip files from [here](https://anu365-my.sharepoint.com/:f:/g/personal/u6293587_anu_edu_au/EuOBUDUQNClJvCpQ8bD1hnoBjdRBWxsHOVp946YVahiMGg?e=F4yRAC) (the project page is [Sat2StrPanoramaSynthesis](https://github.com/shiyujiao/Sat2StrPanoramaSynthesis)):
+
+- `CVACT/satview_correct.zip`
+- `CVACT/streetview.zip`
+- `CVUSA/bingmap/19.zip`
+- `CVUSA/streetview/panos.zip`
+
+Then download the sky masks from [here](https://drive.google.com/drive/folders/1pfzwONg4P-Mzvxvzb2HoCpuZFynElPCk?usp=sharing).
+
+Finally, organize the dataset as follows:
+```
+dataset
+├── CVACT
+│   ├── streetview
+│   ├── satview_correct
+│   ├── pano_sky_mask
+│   └── ACT_data.mat
+└── CVUSA
+    ├── bingmap
+    │   └── 19
+    ├── streetview
+    │   └── panos
+    ├── sky_mask
+    └── splits
+```
+
+Tip: The sky masks were generated with [Trans4PASS](https://github.com/jamycheung/Trans4PASS).
diff --git a/demo_img/case1/groundview.image.png b/demo_img/case1/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0
Binary files /dev/null and b/demo_img/case1/groundview.image.png differ
diff --git a/demo_img/case1/groundview.sky.png b/demo_img/case1/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..413a82e27fd781030a1314513e859610ab48f652
Binary files /dev/null and b/demo_img/case1/groundview.sky.png differ
diff --git a/demo_img/case1/satview-input.png b/demo_img/case1/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..375575a45a1d1b33a87c0133e9237d0b89af9f3e
Binary files /dev/null and b/demo_img/case1/satview-input.png differ
diff --git a/demo_img/case10/groundview.image.png b/demo_img/case10/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0
Binary files /dev/null and b/demo_img/case10/groundview.image.png differ
diff --git a/demo_img/case10/groundview.sky.png b/demo_img/case10/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..5252275f364573f59f2f3b57a177327c297f523a
Binary files /dev/null and b/demo_img/case10/groundview.sky.png differ
diff --git a/demo_img/case10/satview-input.png b/demo_img/case10/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..efe5929d06fb9be3eeda5deb84733614ce571b80
Binary files /dev/null and b/demo_img/case10/satview-input.png differ
diff --git a/demo_img/case11/groundview.image.png b/demo_img/case11/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4c7ad7b5281dd57dc161173699a570592978f6a
Binary files /dev/null and b/demo_img/case11/groundview.image.png differ
diff --git a/demo_img/case11/groundview.sky.png b/demo_img/case11/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..31d697c17321a98ad5838420419a2c8028da20b4
Binary files /dev/null and b/demo_img/case11/groundview.sky.png differ
diff --git a/demo_img/case11/satview-input.png b/demo_img/case11/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..78c725dc36987de0b11c1f044294deeb86d7e20c
Binary files /dev/null and b/demo_img/case11/satview-input.png differ
diff --git a/demo_img/case12/groundview.image.png b/demo_img/case12/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e74053765745892dc5847955d14fc8210efcc94
Binary files /dev/null and b/demo_img/case12/groundview.image.png differ
diff --git a/demo_img/case12/groundview.sky.png b/demo_img/case12/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..70c7652da7c488dccc1393b3abf7f916e00662bb
Binary files /dev/null and b/demo_img/case12/groundview.sky.png differ
diff --git a/demo_img/case12/satview-input.png b/demo_img/case12/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..28858ab35e3c67162cf717e297867e44f545c2ab
Binary files /dev/null and b/demo_img/case12/satview-input.png differ
diff --git a/demo_img/case13/groundview.image.png b/demo_img/case13/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5
Binary files /dev/null and b/demo_img/case13/groundview.image.png differ
diff --git a/demo_img/case13/groundview.sky.png b/demo_img/case13/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2
Binary files /dev/null and b/demo_img/case13/groundview.sky.png differ
diff --git a/demo_img/case13/satview-input.png b/demo_img/case13/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1
Binary files /dev/null and b/demo_img/case13/satview-input.png differ
diff --git a/demo_img/case2/groundview.image.png b/demo_img/case2/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5
Binary files /dev/null and b/demo_img/case2/groundview.image.png differ
diff --git a/demo_img/case2/groundview.sky.png b/demo_img/case2/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2
Binary files /dev/null and b/demo_img/case2/groundview.sky.png differ
diff --git a/demo_img/case2/satview-input.png b/demo_img/case2/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1
Binary files /dev/null and b/demo_img/case2/satview-input.png differ
diff --git a/demo_img/case3/groundview.image.png b/demo_img/case3/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee07e66c811c49961b5b4cf83e2bde17c1c4ab96
Binary files /dev/null and b/demo_img/case3/groundview.image.png differ
diff --git a/demo_img/case3/groundview.sky.png b/demo_img/case3/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..bf1580c092c0e7b80f8199d845f71acc8bf23ee3
Binary files /dev/null and b/demo_img/case3/groundview.sky.png differ
diff --git a/demo_img/case3/satview-input.png b/demo_img/case3/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..d270d4f7ae9a135b0df74e35c7f5ba5440bd07cd
Binary files /dev/null and b/demo_img/case3/satview-input.png differ
diff --git a/demo_img/case4/groundview.image.png b/demo_img/case4/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdd86213b5efdf40e457a5fbead6d2592a40dc0d
Binary files /dev/null and b/demo_img/case4/groundview.image.png differ
diff --git a/demo_img/case4/groundview.sky.png b/demo_img/case4/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..996b7120fbb55c5987c9b58ba1ccf28d16747960
Binary files /dev/null and b/demo_img/case4/groundview.sky.png differ
diff --git a/demo_img/case4/satview-input.png b/demo_img/case4/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..e16179fe548be5115eb5906ae7408c56fdae62f8
Binary files /dev/null and b/demo_img/case4/satview-input.png differ
diff --git a/demo_img/case5/groundview.image.png b/demo_img/case5/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..a9266ebdaead7935d8ea0f834cf33e351da0dd5a
Binary files /dev/null and b/demo_img/case5/groundview.image.png differ
diff --git a/demo_img/case5/groundview.sky.png b/demo_img/case5/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..8598b9e878bcbc54d1b5eaa90e729d2dd77308b0
Binary files /dev/null and b/demo_img/case5/groundview.sky.png differ
diff --git a/demo_img/case5/satview-input.png b/demo_img/case5/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c12242298ef0f2721ec629e780dc75d3af8e810
Binary files /dev/null and b/demo_img/case5/satview-input.png differ
diff --git a/demo_img/case6/groundview.image.png b/demo_img/case6/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..71aab5b4aeca63f04b8d0f4eb8e29f5075612c48
Binary files /dev/null and b/demo_img/case6/groundview.image.png differ
diff --git a/demo_img/case6/groundview.sky.png b/demo_img/case6/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d00e485d6cf4ce46aba8274c527598c54e9b970
Binary files /dev/null and b/demo_img/case6/groundview.sky.png differ
diff --git a/demo_img/case6/satview-input.png b/demo_img/case6/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0837bce0790d1795f0079f8c0c76b53a3fa36b9
Binary files /dev/null and b/demo_img/case6/satview-input.png differ
diff --git a/demo_img/case7/groundview.image.png b/demo_img/case7/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..fcbfd8e23a6411d0cbbac43c15efcea4d6a6dd20
Binary files /dev/null and b/demo_img/case7/groundview.image.png differ
diff --git a/demo_img/case7/groundview.sky.png b/demo_img/case7/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..60ce35217b3faf8894d9cbd3e8ce58c77c29d28b
Binary files /dev/null and b/demo_img/case7/groundview.sky.png differ
diff --git a/demo_img/case7/satview-input.png b/demo_img/case7/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..708349179ddfdbeff7fae9ce516a5e49323bc2f8
Binary files /dev/null and b/demo_img/case7/satview-input.png differ
diff --git a/demo_img/case8/groundview.image.png b/demo_img/case8/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d52022a5becfb48728b8ea23679dffeee870ea
Binary files /dev/null and b/demo_img/case8/groundview.image.png differ
diff --git a/demo_img/case8/groundview.sky.png b/demo_img/case8/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..94731c15c751e9df883806ac0d43fbfedb270ab7
Binary files /dev/null and b/demo_img/case8/groundview.sky.png differ
diff --git a/demo_img/case8/satview-input.png b/demo_img/case8/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..981ce9f97d7d182ff441a3f925fb1d03598f9ae2
Binary files /dev/null and b/demo_img/case8/satview-input.png differ
diff --git a/demo_img/case9/groundview.image.png b/demo_img/case9/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..0084c60d4618d184db5e0b383403a20ce1644555
Binary files /dev/null and b/demo_img/case9/groundview.image.png differ
diff --git a/demo_img/case9/groundview.sky.png b/demo_img/case9/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..649db36fe97c73ea18f7712941de6c6b0f7bad19
Binary files /dev/null and b/demo_img/case9/groundview.sky.png differ
diff --git a/demo_img/case9/satview-input.png b/demo_img/case9/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d0484df8b32f566532c6345abb98eec6e9195e
Binary files /dev/null and b/demo_img/case9/satview-input.png differ
diff --git a/demo_img/runall.sh b/demo_img/runall.sh
new file mode 100644
index 0000000000000000000000000000000000000000..975924bab8dab46fb3e42c0454985aa93afbdb1e
--- /dev/null
+++ b/demo_img/runall.sh
@@ -0,0 +1,30 @@
+# for case in `ls -d demo_img/case*`
+for case_id in 1 2 3 4
+do
+    case=demo_img/case$case_id
+    echo $case
+    python test.py --yaml=sat2density_cvact \
+    --test_ckpt_path=2u87bj8w \
+    --task=test_vid \
+    --demo_img=$case/satview-input.png  \
+    --sty_img=$case/groundview.image.png  \
+    --save_dir=results/$case
+    # ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png results/$case/render.gif
+    ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -vf "palettegen" results/$case-palette.png
+    ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/render.gif
+
+    ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -vf "palettegen" results/$case-palette.png
+    ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/sat.gif
+    # ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png results/$case/sat.gif
+done
+
+# for case in `ls -d demo_img/case*`
+for case_id in 1 2 3 4
+do
+    case=demo_img/case$case_id
+    sat_gif=results/$case/sat.gif
+    render_gif=results/$case/render.gif
+    # echo $sat_gif
+    cp $sat_gif docs/figures/demo/case$case_id.sat.gif
+    cp $render_gif docs/figures/demo/case$case_id.render.gif
+done
\ No newline at end of file
diff --git a/imaginaire/__init__.py b/imaginaire/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/__pycache__/__init__.cpython-38.pyc b/imaginaire/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e39ab8b1c6c47a307883db29f98f32adce2f4ea6
Binary files /dev/null and b/imaginaire/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/config.py b/imaginaire/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a728a5aaee8d040288ff9ffd17a4fa83a7e2ca7
--- /dev/null
+++ b/imaginaire/config.py
@@ -0,0 +1,238 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Config utilities for yml file."""
+
+import collections
+import functools
+import os
+import re
+
+import yaml
+from imaginaire.utils.distributed import master_only_print as print
+
+DEBUG = False
+USE_JIT = False
+
+
+class AttrDict(dict):
+    """Dict as attribute trick."""
+
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+        for key, value in self.__dict__.items():
+            if isinstance(value, dict):
+                self.__dict__[key] = AttrDict(value)
+            elif isinstance(value, (list, tuple)):
+                if isinstance(value[0], dict):
+                    self.__dict__[key] = [AttrDict(item) for item in value]
+                else:
+                    self.__dict__[key] = value
+
+    def yaml(self):
+        """Convert object to yaml dict and return."""
+        yaml_dict = {}
+        for key, value in self.__dict__.items():
+            if isinstance(value, AttrDict):
+                yaml_dict[key] = value.yaml()
+            elif isinstance(value, list):
+                if isinstance(value[0], AttrDict):
+                    new_l = []
+                    for item in value:
+                        new_l.append(item.yaml())
+                    yaml_dict[key] = new_l
+                else:
+                    yaml_dict[key] = value
+            else:
+                yaml_dict[key] = value
+        return yaml_dict
+
+    def __repr__(self):
+        """Print all variables."""
+        ret_str = []
+        for key, value in self.__dict__.items():
+            if isinstance(value, AttrDict):
+                ret_str.append('{}:'.format(key))
+                child_ret_str = value.__repr__().split('\n')
+                for item in child_ret_str:
+                    ret_str.append('    ' + item)
+            elif isinstance(value, list):
+                if isinstance(value[0], AttrDict):
+                    ret_str.append('{}:'.format(key))
+                    for item in value:
+                        # Treat as AttrDict above.
+                        child_ret_str = item.__repr__().split('\n')
+                        for item in child_ret_str:
+                            ret_str.append('    ' + item)
+                else:
+                    ret_str.append('{}: {}'.format(key, value))
+            else:
+                ret_str.append('{}: {}'.format(key, value))
+        return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+    r"""Configuration class. This should include every human specifiable
+    hyperparameter values for your training."""
+
+    def __init__(self, filename=None, verbose=False):
+        super(Config, self).__init__()
+        self.source_filename = filename
+        # Set default parameters.
+        # Logging.
+        large_number = 1000000000
+        self.snapshot_save_iter = large_number
+        self.snapshot_save_epoch = large_number
+        self.metrics_iter = None
+        self.metrics_epoch = None
+        self.snapshot_save_start_iter = 0
+        self.snapshot_save_start_epoch = 0
+        self.image_save_iter = large_number
+        self.image_display_iter = large_number
+        self.max_epoch = large_number
+        self.max_iter = large_number
+        self.logging_iter = 100
+        self.speed_benchmark = False
+
+        # Trainer.
+        self.trainer = AttrDict(
+            model_average_config=AttrDict(enabled=False,
+                                          beta=0.9999,
+                                          start_iteration=1000,
+                                          num_batch_norm_estimation_iterations=30,
+                                          remove_sn=True),
+            # model_average=False,
+            # model_average_beta=0.9999,
+            # model_average_start_iteration=1000,
+            # model_average_batch_norm_estimation_iteration=30,
+            # model_average_remove_sn=True,
+            image_to_tensorboard=False,
+            hparam_to_tensorboard=False,
+            distributed_data_parallel='pytorch',
+            distributed_data_parallel_params=AttrDict(
+                find_unused_parameters=False),
+            delay_allreduce=True,
+            gan_relativistic=False,
+            gen_step=1,
+            dis_step=1,
+            gan_decay_k=1.,
+            gan_min_k=1.,
+            gan_separate_topk=False,
+            aug_policy='',
+            channels_last=False,
+            strict_resume=True,
+            amp_gp=False,
+            amp_config=AttrDict(init_scale=65536.0,
+                                growth_factor=2.0,
+                                backoff_factor=0.5,
+                                growth_interval=2000,
+                                enabled=False))
+
+        # Networks.
+        self.gen = AttrDict(type='imaginaire.generators.dummy')
+        self.dis = AttrDict(type='imaginaire.discriminators.dummy')
+
+        # Optimizers.
+        self.gen_opt = AttrDict(type='adam',
+                                fused_opt=False,
+                                lr=0.0001,
+                                adam_beta1=0.0,
+                                adam_beta2=0.999,
+                                eps=1e-8,
+                                lr_policy=AttrDict(iteration_mode=False,
+                                                   type='step',
+                                                   step_size=large_number,
+                                                   gamma=1))
+        self.dis_opt = AttrDict(type='adam',
+                                fused_opt=False,
+                                lr=0.0001,
+                                adam_beta1=0.0,
+                                adam_beta2=0.999,
+                                eps=1e-8,
+                                lr_policy=AttrDict(iteration_mode=False,
+                                                   type='step',
+                                                   step_size=large_number,
+                                                   gamma=1))
+        # Data.
+        self.data = AttrDict(name='dummy',
+                             type='imaginaire.datasets.images',
+                             num_workers=0)
+        self.test_data = AttrDict(name='dummy',
+                                  type='imaginaire.datasets.images',
+                                  num_workers=0,
+                                  test=AttrDict(is_lmdb=False,
+                                                roots='',
+                                                batch_size=1))
+
+        # Cudnn.
+        self.cudnn = AttrDict(deterministic=False,
+                              benchmark=True)
+
+        # Others.
+        self.pretrained_weight = ''
+        self.inference_args = AttrDict()
+
+        # Update with given configurations.
+        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+        loader = yaml.SafeLoader
+        loader.add_implicit_resolver(
+            u'tag:yaml.org,2002:float',
+            re.compile(u'''^(?:
+             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+            |[-+]?\\.(?:inf|Inf|INF)
+            |\\.(?:nan|NaN|NAN))$''', re.X),
+            list(u'-+0123456789.'))
+        try:
+            with open(filename, 'r') as f:
+                cfg_dict = yaml.load(f, Loader=loader)
+        except EnvironmentError:
+            print('Please check the file with name of "{}"'.format(filename))
+        recursive_update(self, cfg_dict)
+
+        # Put common opts in both gen and dis.
+        if 'common' in cfg_dict:
+            self.common = AttrDict(**cfg_dict['common'])
+            self.gen.common = self.common
+            self.dis.common = self.common
+
+        if verbose:
+            print(' imaginaire config '.center(80, '-'))
+            print(self.__repr__())
+            print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+    """Recursively find object and set value"""
+    pre, _, post = attr.rpartition('.')
+    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+    """Recursively find object and return value"""
+
+    def _getattr(obj, attr):
+        r"""Get attribute."""
+        return getattr(obj, attr, *args)
+
+    return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+    """Recursively update AttrDict d with AttrDict u"""
+    for key, value in u.items():
+        if isinstance(value, collections.abc.Mapping):
+            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+        elif isinstance(value, (list, tuple)):
+            if isinstance(value[0], dict):
+                d.__dict__[key] = [AttrDict(item) for item in value]
+            else:
+                d.__dict__[key] = value
+        else:
+            d.__dict__[key] = value
+    return d
diff --git a/imaginaire/datasets/__init__.py b/imaginaire/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/datasets/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/datasets/base.py b/imaginaire/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9192f78d4c3cfb65ac73751b632f897893d8288
--- /dev/null
+++ b/imaginaire/datasets/base.py
@@ -0,0 +1,596 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""All datasets are inherited from this class."""
+
+import importlib
+import json
+import os
+import pickle
+from collections import OrderedDict
+from functools import partial
+from inspect import signature
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+
+from imaginaire.datasets.folder import FolderDataset
+from imaginaire.datasets.lmdb import \
+    IMG_EXTENSIONS, HDR_IMG_EXTENSIONS, LMDBDataset
+from imaginaire.datasets.object_store import ObjectStoreDataset
+from imaginaire.utils.data import \
+    (VIDEO_EXTENSIONS, Augmentor,
+     load_from_folder, load_from_lmdb, load_from_object_store)
+from imaginaire.utils.lmdb import create_metadata
+
+
+DATASET_TYPES = ['lmdb', 'folder', 'object_store']
+
+
+class BaseDataset(data.Dataset):
+    r"""Base class for image/video datasets.
+
+    Args:
+        cfg (Config object): Input config.
+        is_inference (bool): Training if False, else validation.
+        is_test (bool): Final test set after training and validation.
+    """
+
+    def __init__(self, cfg, is_inference, is_test):
+        super(BaseDataset, self).__init__()
+
+        self.cfg = cfg
+        self.is_inference = is_inference
+        self.is_test = is_test
+        if self.is_test:
+            self.cfgdata = self.cfg.test_data
+            data_info = self.cfgdata.test
+        else:
+            self.cfgdata = self.cfg.data
+            if self.is_inference:
+                data_info = self.cfgdata.val
+            else:
+                data_info = self.cfgdata.train
+        self.name = self.cfgdata.name
+        self.lmdb_roots = data_info.roots
+        self.dataset_type = getattr(data_info, 'dataset_type', None)
+        self.cache = getattr(self.cfgdata, 'cache', None)
+        self.interpolator = getattr(self.cfgdata, 'interpolator', "INTER_LINEAR")
+
+        # Get AWS secret keys.
+        if self.dataset_type == 'object_store':
+            assert hasattr(cfg, 'aws_credentials_file')
+            self.aws_credentials_file = cfg.aws_credentials_file
+
+        # Legacy lmdb/folder only support.
+        if self.dataset_type is None:
+            self.dataset_is_lmdb = getattr(data_info, 'is_lmdb', False)
+            if self.dataset_is_lmdb:
+                self.dataset_type = 'lmdb'
+            else:
+                self.dataset_type = 'folder'
+        # Legacy support ends.
+
+        assert self.dataset_type in DATASET_TYPES
+        if self.dataset_type == 'lmdb':
+            # Add handle to function to load data from LMDB.
+            self.load_from_dataset = load_from_lmdb
+        elif self.dataset_type == 'folder':
+            # For some unpaired experiments, we would like the dataset to be presented in a paired way
+
+            if hasattr(self.cfgdata, 'paired') is False:
+                self.cfgdata.paired = self.paired
+            # Add handle to function to load data from folder.
+            self.load_from_dataset = load_from_folder
+            # Create metadata for folders.
+            print('Creating metadata')
+            all_filenames, all_metadata = [], []
+            if self.is_test:
+                cfg.data_backup = cfg.data
+                cfg.data = cfg.test_data
+            for root in self.lmdb_roots:
+                filenames, metadata = create_metadata(
+                    data_root=root, cfg=cfg, paired=self.cfgdata['paired'])
+                all_filenames.append(filenames)
+                all_metadata.append(metadata)
+            if self.is_test:
+                cfg.data = cfg.data_backup
+        elif self.dataset_type == 'object_store':
+            # Add handle to function to load data from AWS S3.
+            self.load_from_dataset = load_from_object_store
+
+        # Get the types of data stored in dataset, and their extensions.
+        self.data_types = []  # Names of data types.
+        self.dataset_data_types = []  # These data types are in the dataset.
+        self.image_data_types = []  # These types are images.
+        self.hdr_image_data_types = []  # These types are HDR images.
+        self.normalize = {}  # Does this data type need normalization?
+        self.extensions = {}  # What is this data type's file extension.
+        self.is_mask = {}  # Whether this data type is discrete masks?
+        self.num_channels = {}  # How many channels does this data type have?
+        self.pre_aug_ops = {}  # Ops on data type before augmentation.
+        self.post_aug_ops = {}  # Ops on data type after augmentation.
+
+        # Extract info from data types.
+        for data_type in self.cfgdata.input_types:
+            name = list(data_type.keys())
+            assert len(name) == 1
+            name = name[0]
+            info = data_type[name]
+
+            if 'ext' not in info:
+                info['ext'] = None
+            if 'normalize' not in info:
+                info['normalize'] = False
+            if 'is_mask' not in info:
+                info['is_mask'] = False
+            if 'pre_aug_ops' not in info:
+                info['pre_aug_ops'] = 'None'
+            if 'post_aug_ops' not in info:
+                info['post_aug_ops'] = 'None'
+            if 'computed_on_the_fly' not in info:
+                info['computed_on_the_fly'] = False
+            if 'num_channels' not in info:
+                info['num_channels'] = None
+
+            self.data_types.append(name)
+            if not info['computed_on_the_fly']:
+                self.dataset_data_types.append(name)
+
+            self.extensions[name] = info['ext']
+            self.normalize[name] = info['normalize']
+            self.num_channels[name] = info['num_channels']
+            self.pre_aug_ops[name] = [op.strip() for op in
+                                      info['pre_aug_ops'].split(',')]
+            self.post_aug_ops[name] = [op.strip() for op in
+                                       info['post_aug_ops'].split(',')]
+            self.is_mask[name] = info['is_mask']
+            if info['ext'] is not None and (info['ext'] in IMG_EXTENSIONS or info['ext'] in VIDEO_EXTENSIONS):
+                self.image_data_types.append(name)
+            if info['ext'] is not None and info['ext'] in HDR_IMG_EXTENSIONS:
+                self.hdr_image_data_types.append(name)
+
+        # Add some info into cfgdata for legacy support.
+        self.cfgdata.data_types = self.data_types
+        self.cfgdata.num_channels = [self.num_channels[name]
+                                     for name in self.data_types]
+
+        # Augmentations which need full dict.
+        self.full_data_post_aug_ops, self.full_data_ops = [], []
+        if hasattr(self.cfgdata, 'full_data_ops'):
+            ops = self.cfgdata.full_data_ops
+            self.full_data_ops.extend([op.strip() for op in ops.split(',')])
+        if hasattr(self.cfgdata, 'full_data_post_aug_ops'):
+            ops = self.cfgdata.full_data_post_aug_ops
+            self.full_data_post_aug_ops.extend(
+                [op.strip() for op in ops.split(',')])
+
+        # These are the labels which will be concatenated for generator input.
+        self.input_labels = []
+        if hasattr(self.cfgdata, 'input_labels'):
+            self.input_labels = self.cfgdata.input_labels
+
+        # These are the keypoints which also need to be augmented.
+        self.keypoint_data_types = []
+        if hasattr(self.cfgdata, 'keypoint_data_types'):
+            self.keypoint_data_types = self.cfgdata.keypoint_data_types
+
+        # Create augmentation operations.
+        aug_list = data_info.augmentations
+        individual_video_frame_aug_list = getattr(data_info, 'individual_video_frame_augmentations', dict())
+        self.augmentor = Augmentor(
+            aug_list, individual_video_frame_aug_list, self.image_data_types, self.is_mask,
+            self.keypoint_data_types, self.interpolator)
+        self.augmentable_types = self.image_data_types + \
+            self.keypoint_data_types
+
+        # Create torch transformations.
+        self.transform = {}
+        for data_type in self.image_data_types:
+            normalize = self.normalize[data_type]
+            self.transform[data_type] = self._get_transform(
+                normalize, self.num_channels[data_type])
+
+        # Create torch transformations for HDR images.
+        for data_type in self.hdr_image_data_types:
+            normalize = self.normalize[data_type]
+            self.transform[data_type] = self._get_transform(
+                normalize, self.num_channels[data_type])
+
+        # Initialize handles.
+        self.sequence_lists = []  # List of sequences per dataset root.
+        self.lmdbs = {}  # Dict for list of lmdb handles per data type.
+        for data_type in self.dataset_data_types:
+            self.lmdbs[data_type] = []
+        self.dataset_probability = None
+        self.additional_lists = []
+
+        # Load each dataset.
+        for idx, root in enumerate(self.lmdb_roots):
+            if self.dataset_type == 'lmdb':
+                self._add_dataset(root)
+            elif self.dataset_type == 'folder':
+                self._add_dataset(root, filenames=all_filenames[idx],
+                                  metadata=all_metadata[idx])
+            elif self.dataset_type == 'object_store':
+                self._add_dataset(
+                    root, aws_credentials_file=self.aws_credentials_file)
+
+        # Compute dataset statistics and create whatever self.variables required
+        # for the specific dataloader.
+        self._compute_dataset_stats()
+
+        # Build index of data to sample.
+        self.mapping, self.epoch_length = self._create_mapping()
+
+    def _create_mapping(self):
+        r"""Creates mapping from data sample idx to actual LMDB keys.
+            All children need to implement their own.
+
+        Returns:
+            (tuple):
+              - self.mapping (list): List of LMDB keys.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        raise NotImplementedError
+
+    def _compute_dataset_stats(self):
+        r"""Computes required statistics about dataset.
+           All children need to implement their own.
+        """
+        pass
+
+    def __getitem__(self, index):
+        r"""Entry function for dataset."""
+        raise NotImplementedError
+
+    def _get_transform(self, normalize, num_channels):
+        r"""Convert numpy to torch tensor.
+
+        Args:
+            normalize (bool): Normalize image i.e. (x - 0.5) * 2.
+                Goes from [0, 1] -> [-1, 1].
+            num_channels (int): Number of channels to normalize.
+        Returns:
+            Composed list of torch transforms.
+        """
+        transform_list = [transforms.ToTensor()]
+        if normalize:
+            transform_list.append(
+                transforms.Normalize((0.5, ) * num_channels,
+                                     (0.5, ) * num_channels, inplace=True))
+        return transforms.Compose(transform_list)
+
+    def _add_dataset(self, root, filenames=None, metadata=None,
+                     aws_credentials_file=None):
+        r"""Adds an LMDB dataset to a list of datasets.
+
+        Args:
+            root (str): Path to LMDB or folder dataset.
+            filenames: List of filenames for folder dataset.
+            metadata: Metadata for folder dataset.
+            aws_credentials_file: Path to file containing AWS credentials.
+        """
+        if aws_credentials_file and self.dataset_type == 'object_store':
+            object_store_dataset = ObjectStoreDataset(
+                root, aws_credentials_file, cache=self.cache)
+            sequence_list = object_store_dataset.sequence_list
+        else:
+            # Get sequences associated with this dataset.
+            if filenames is None:
+                list_path = 'all_filenames.json'
+                with open(os.path.join(root, list_path)) as fin:
+                    sequence_list = OrderedDict(json.load(fin))
+            else:
+                sequence_list = filenames
+
+            additional_path = 'all_indices.json'
+            if os.path.exists(os.path.join(root, additional_path)):
+                print('Using additional list for object indices.')
+                with open(os.path.join(root, additional_path)) as fin:
+                    additional_list = OrderedDict(json.load(fin))
+                self.additional_lists.append(additional_list)
+        self.sequence_lists.append(sequence_list)
+
+        # Get LMDB dataset handles.
+        for data_type in self.dataset_data_types:
+            if self.dataset_type == 'lmdb':
+                self.lmdbs[data_type].append(
+                    LMDBDataset(os.path.join(root, data_type)))
+            elif self.dataset_type == 'folder':
+                self.lmdbs[data_type].append(
+                    FolderDataset(os.path.join(root, data_type), metadata))
+            elif self.dataset_type == 'object_store':
+                # All data types use the same handle.
+                self.lmdbs[data_type].append(object_store_dataset)
+
+    def perform_individual_video_frame(self, data, augment_ops):
+        r"""Perform data augmentation on images only.
+
+        Args:
+            data (dict): Keys are from data types. Values can be numpy.ndarray
+                or list of numpy.ndarray (image or list of images).
+            augment_ops (list): The augmentation operations for individual frames.
+        Returns:
+            (tuple):
+              - data (dict): Augmented data, with same keys as input data.
+              - is_flipped (bool): Flag which tells if images have been
+                left-right flipped.
+        """
+        if augment_ops:
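+            # Split the batched values into one dict per frame so that each
+            # frame is augmented independently, then merge the results back.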
+            all_data = dict()
+            for ix, key in enumerate(data.keys()):
+                if ix == 0:
+                    num = len(data[key])
+                    for j in range(num):
+                        all_data['%d' % j] = dict()
+                for j in range(num):
+                    all_data['%d' % j][key] = data[key][j:(j+1)]
+            for j in range(num):
+                all_data['%d' % j], _ = self.perform_augmentation(
+                    all_data['%d' % j], paired=True, augment_ops=augment_ops)
+            for key in data.keys():
+                tmp = []
+                for j in range(num):
+                    tmp += all_data['%d' % j][key]
+                data[key] = tmp
+        return data
+
+    def perform_augmentation(self, data, paired, augment_ops=None):
+        r"""Perform data augmentation on images only.
+
+        Args:
+            data (dict): Keys are from data types. Values can be numpy.ndarray
+                or list of numpy.ndarray (image or list of images).
+            paired (bool): Apply same augmentation to all input keys?
+            augment_ops (list): The augmentation operations.
+        Returns:
+            (tuple):
+              - data (dict): Augmented data, with same keys as input data.
+              - is_flipped (bool): Flag which tells if images have been
+                left-right flipped.
+        """
+        aug_inputs = {}
+        for data_type in self.augmentable_types:
+            aug_inputs[data_type] = data[data_type]
+
+        augmented, is_flipped = self.augmentor.perform_augmentation(
+            aug_inputs, paired=paired, augment_ops=augment_ops)
+
+        for data_type in self.augmentable_types:
+            data[data_type] = augmented[data_type]
+
+        return data, is_flipped
+
+    def flip_hdr(self, data, is_flipped=False):
+        r"""Flip hdr images.
+
+        Args:
+            data (dict): Keys are from data types. Values can be numpy.ndarray
+                or list of numpy.ndarray (image or list of images).
+            is_flipped (bool): Applying left-right flip to the hdr images
+        Returns:
+            (tuple):
+              - data (dict): Augmented data, with same keys as input data.
+        """
+        if is_flipped is False:
+            return data
+
+        for data_type in self.hdr_image_data_types:
+            # print('Length of data: {}'.format(len(data[data_type])))
+            data[data_type][0] = data[data_type][0][:, ::-1, :].copy()
+        return data
+
+    def to_tensor(self, data):
+        r"""Convert all images to tensor.
+
+        Args:
+            data (dict): Dict containing data_type as key, with each value
+                as a list of numpy.ndarrays.
+        Returns:
+            data (dict): Dict containing data_type as key, with each value
+            as a list of torch.Tensors.
+        """
+        for data_type in self.image_data_types:
+            for idx in range(len(data[data_type])):
+                if data[data_type][idx].dtype == np.uint16:
+                    data[data_type][idx] = data[data_type][idx].astype(
+                        np.float32)
+                data[data_type][idx] = self.transform[data_type](
+                    data[data_type][idx])
+        for data_type in self.hdr_image_data_types:
+            for idx in range(len(data[data_type])):
+                data[data_type][idx] = self.transform[data_type](
+                    data[data_type][idx])
+        return data
+
+    def apply_ops(self, data, op_dict, full_data=False):
+        r"""Apply any ops from op_dict to data types.
+
+        Args:
+            data (dict): Dict containing data_type as key, with each value
+                as a list of numpy.ndarrays.
+            op_dict (dict): Dict containing data_type as key, with each value
+                containing string of operations to apply.
+            full_data (bool): Do these ops require access to the full data?
+        Returns:
+            data (dict): Dict containing data_type as key, with each value
+            modified by the op if any.
+        """
+        if full_data:
+            # op needs entire data dict.
+            for op in op_dict:
+                if op == 'None':
+                    continue
+                op, op_type = self.get_op(op)
+                assert op_type == 'full_data'
+                data = op(data)
+        else:
+            # op per data type.
+            if not op_dict:
+                return data
+            for data_type in data:
+                for op in op_dict[data_type]:
+                    if op == 'None':
+                        continue
+                    op, op_type = self.get_op(op)
+                    data[data_type] = op(data[data_type])
+
+                    if op_type == 'vis':
+                        # We have converted this data type to an image. Enter it
+                        # in self.image_data_types and give it a torch
+                        # transform.
+                        if data_type not in self.image_data_types:
+                            self.image_data_types.append(data_type)
+                            normalize = self.normalize[data_type]
+                            num_channels = self.num_channels[data_type]
+                            self.transform[data_type] = \
+                                self._get_transform(normalize, num_channels)
+                    elif op_type == 'convert':
+                        continue
+                    elif op_type is None:
+                        continue
+                    else:
+                        raise NotImplementedError
+        return data
+
+    def get_op(self, op):
+        r"""Get function to apply for specific op.
+
+        Args:
+            op (str): Name of the op.
+        Returns:
+            function handle.
+        """
+        def list_to_tensor(data):
+            r"""Convert list of numeric values to tensor."""
+            assert isinstance(data, list)
+            return torch.from_numpy(np.array(data, dtype=np.float32))
+
+        def decode_json_list(data):
+            r"""Decode list of strings in json to objects."""
+            assert isinstance(data, list)
+            return [json.loads(item) for item in data]
+
+        def decode_pkl_list(data):
+            r"""Decode list of pickled strings to objects."""
+            assert isinstance(data, list)
+            return [pickle.loads(item) for item in data]
+
+        def list_to_numpy(data):
+            r"""Convert list of numeric values to numpy array."""
+            assert isinstance(data, list)
+            return np.array(data)
+
+        def l2_normalize(data):
+            r"""L2 normalization."""
+            assert isinstance(data, torch.Tensor)
+            import torch.nn.functional as F
+            return F.normalize(data, dim=1)
+
+        if op == 'to_tensor':
+            return list_to_tensor, None
+        elif op == 'decode_json':
+            return decode_json_list, None
+        elif op == 'decode_pkl':
+            return decode_pkl_list, None
+        elif op == 'to_numpy':
+            return list_to_numpy, None
+        elif op == 'l2_norm':
+            return l2_normalize, None
+        elif '::' in op:
+            parts = op.split('::')
+            if len(parts) == 2:
+                module, function = parts
+                module = importlib.import_module(module)
+                function = getattr(module, function)
+                sig = signature(function)
+                num_params = len(sig.parameters)
+                assert num_params in [3, 4], \
+                    'Full data functions take in (cfgdata, is_inference, ' \
+                    'full_data) or (cfgdata, is_inference, self, full_data) ' \
+                    'as input.'
+                if num_params == 3:
+                    function = partial(
+                        function, self.cfgdata, self.is_inference)
+                elif num_params == 4:
+                    function = partial(
+                        function, self.cfgdata, self.is_inference, self)
+                function_type = 'full_data'
+            elif len(parts) == 3:
+                function_type, module, function = parts
+                module = importlib.import_module(module)
+
+                # Get function inputs, if provided.
+                partial_fn = False
+                if '(' in function and ')' in function:
+                    partial_fn = True
+                    function, params = self._get_fn_params(function)
+
+                function = getattr(module, function)
+
+                # Create partial function.
+                if partial_fn:
+                    function = partial(function, **params)
+
+                # Get function signature.
+                sig = signature(function)
+                num_params = 0
+                for param in sig.parameters.values():
+                    if param.kind == param.POSITIONAL_OR_KEYWORD:
+                        num_params += 1
+
+                if function_type == 'vis':
+                    if num_params != 9:
+                        raise ValueError(
+                            'vis function type needs to take ' +
+                            '(resize_h, resize_w, crop_h, crop_w, ' +
+                            'original_h, original_w, is_flipped, cfgdata, ' +
+                            'data) as input.')
+                    function = partial(function,
+                                       self.augmentor.resize_h,
+                                       self.augmentor.resize_w,
+                                       self.augmentor.crop_h,
+                                       self.augmentor.crop_w,
+                                       self.augmentor.original_h,
+                                       self.augmentor.original_w,
+                                       self.augmentor.is_flipped,
+                                       self.cfgdata)
+                elif function_type == 'convert':
+                    if num_params != 1:
+                        raise ValueError(
+                            'convert function type needs to take ' +
+                            '(data) as input.')
+                else:
+                    raise ValueError('Unknown op: %s' % (op))
+            else:
+                raise ValueError('Unknown op: %s' % (op))
+            return function, function_type
+        else:
+            raise ValueError('Unknown op: %s' % (op))
+
+    def _get_fn_params(self, function_string):
+        r"""Find key-value inputs to function from string definition.
+
+        Args:
+            function_string (str): String with the function name and
+                colon-separated key-value args, e.g. my_function(a=10:b=20).
+        Returns:
+            function (str): Name of function.
+            params (dict): Key-value params for function.
+        """
+        start = function_string.find('(')
+        end = function_string.find(')')
+        function = function_string[:start]
+        params_str = function_string[start+1:end]
+        params = {}
+        for item in params_str.split(':'):
+            key, value = item.split('=')
+            try:
+                params[key] = float(value)
+            except ValueError:
+                # Keep non-numeric values as strings.
+                params[key] = value
+        return function, params
+
+    def __len__(self):
+        return self.epoch_length
diff --git a/imaginaire/datasets/cache.py b/imaginaire/datasets/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c28752dc290c9cbf15ffcf6ca2093415082f93e
--- /dev/null
+++ b/imaginaire/datasets/cache.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import diskcache
+
+"""
+INFO:
+Cache objects are thread-safe and may be shared between threads.
+Two Cache objects may also reference the same directory from separate
+threads or processes. In this way, they are also process-safe and support
+cross-process communication.
+"""
+
+
+class Cache(object):
+    r"""This creates an on-disk cache, which saves files as bytes.
+
+    Args:
+        root (str): Path to the cache dir.
+        size_GB (float): Size of cache in GB.
+    """
+
+    def __init__(self, root, size_GB):
+        self.root = root
+        self.size_limit_B = size_GB * 1024 * 1024 * 1024
+        self.cache = diskcache.Cache(root, size_limit=self.size_limit_B)
+        print('Created cache of max size %d GB at %s' %
+              (size_GB, self.cache.directory))
+
+    def read(self, key):
+        r"""Return the cached value for key, or False if it is not cached."""
+        if key in self.cache:
+            return self.cache[key]
+        return False
+
+    def write(self, key, value):
+        r"""Write value to the cache. Returns True on success, else False."""
+        try:
+            self.cache[key] = value
+        except Exception as e:  # noqa
+            print(e)
+            return False
+        return True
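+
+# Illustrative usage only (path and size are placeholders):
+#   cache = Cache('/tmp/my_cache', size_GB=1)
+#   cache.write('key', b'value')
+#   assert cache.read('key') == b'value'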
diff --git a/imaginaire/datasets/dummy.py b/imaginaire/datasets/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9783eb3f38007652b84dde33f2da9202491686e1
--- /dev/null
+++ b/imaginaire/datasets/dummy.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+
+class Dataset(torch.utils.data.Dataset):
+    r"""Dummy dataset, returns nothing."""
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        super(Dataset, self).__init__()
+
+    def __getitem__(self, index):
+        return {}
+
+    def __len__(self):
+        return 65535
diff --git a/imaginaire/datasets/folder.py b/imaginaire/datasets/folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3fcc679834044ac9bf11afe81a9e9fe8697aa8
--- /dev/null
+++ b/imaginaire/datasets/folder.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import cv2
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+import imageio
+
+
+class FolderDataset(data.Dataset):
+    r"""This deals with opening, and reading from an Folder dataset.
+
+    Args:
+        root (str): Path to the folder.
+        metadata (dict): Containing extensions.
+    """
+
+    def __init__(self, root, metadata):
+        self.root = os.path.expanduser(root)
+        self.extensions = metadata
+
+        print('Folder at %s opened.' % (root))
+
+    def getitem_by_path(self, path, data_type):
+        r"""Load data item stored for key = path.
+
+        Args:
+            path (str): Key into Folder dataset.
+            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+        Returns:
+            img (PIL.Image) or buf (str): Contents of file for this key.
+        """
+        # Figure out decoding params.
+        ext = self.extensions[data_type]
+        is_image = False
+        is_hdr = False
+        if ext in IMG_EXTENSIONS:
+            is_image = True
+            if 'tif' in ext:
+                dtype, mode = np.uint16, -1
+            elif 'JPEG' in ext or 'JPG' in ext \
+                    or 'jpeg' in ext or 'jpg' in ext:
+                dtype, mode = np.uint8, 3
+            else:
+                dtype, mode = np.uint8, -1
+        elif ext in HDR_IMG_EXTENSIONS:
+            is_hdr = True
+        else:
+            is_image = False
+
+        # Get value from key.
+        filepath = os.path.join(self.root, path.decode() + '.' + ext)
+        assert os.path.exists(filepath), '%s does not exist' % (filepath)
+        with open(filepath, 'rb') as f:
+            buf = f.read()
+
+        # Decode and return.
+        if is_image:
+            try:
+                img = cv2.imdecode(np.frombuffer(buf, dtype=dtype), mode)
+            except Exception:
+                print(path)
+            # BGR to RGB if 3 channels.
+            if img.ndim == 3 and img.shape[-1] == 3:
+                img = img[:, :, ::-1]
+            img = Image.fromarray(img)
+            return img
+        elif is_hdr:
+            try:
+                imageio.plugins.freeimage.download()
+                img = imageio.imread(buf)
+            except Exception:
+                print(path)
+            return img  # Return a numpy array
+        else:
+            return buf
+
+    def __len__(self):
+        r"""Return number of files in the folder dataset."""
+        # NOTE: self.length is never set in __init__, so count files on demand
+        # (assumes every file under root is one dataset entry).
+        return sum(len(files) for _, _, files in os.walk(self.root))
diff --git a/imaginaire/datasets/images.py b/imaginaire/datasets/images.py
new file mode 100644
index 0000000000000000000000000000000000000000..943752be11d823025f33956dcdcaf6a35c1fb899
--- /dev/null
+++ b/imaginaire/datasets/images.py
@@ -0,0 +1,168 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+    r"""Image dataset for use in class conditional GAN.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+    """
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        self.paired = False
+        super(Dataset, self).__init__(cfg, is_inference, is_test)
+        self.num_classes = len(self.class_name_to_idx['images'])
+        self.sample_class_idx = None
+
+    def set_sample_class_idx(self, class_idx):
+        r"""Set sample class idx. This is not used in this class...
+
+        Args:
+            class_idx (int): Which class idx to sample from.
+        """
+        self.sample_class_idx = class_idx
+        self.epoch_length = \
+            max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()])
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict with data type as key mapping idx to
+                LMDB key.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        idx_to_key, class_names = {}, {}
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for data_type, data_type_sequence_list in sequence_list.items():
+                class_names[data_type] = []
+                if data_type not in idx_to_key:
+                    idx_to_key[data_type] = []
+                for sequence_name, filenames in data_type_sequence_list.items():
+                    class_name = sequence_name.split('/')[0]
+                    for filename in filenames:
+                        idx_to_key[data_type].append({
+                            'lmdb_root': self.lmdb_roots[lmdb_idx],
+                            'lmdb_idx': lmdb_idx,
+                            'sequence_name': sequence_name,
+                            'filename': filename,
+                            'class_name': class_name
+                        })
+                    class_names[data_type].append(class_name)
+        self.mapping = idx_to_key
+        self.epoch_length = max([len(lmdb_keys)
+                                 for _, lmdb_keys in self.mapping.items()])
+
+        # Create mapping from class name to class idx.
+        self.class_name_to_idx = {}
+        for data_type, class_names_data_type in class_names.items():
+            self.class_name_to_idx[data_type] = {}
+            class_names_data_type = sorted(list(set(class_names_data_type)))
+            for class_idx, class_name in enumerate(class_names_data_type):
+                self.class_name_to_idx[data_type][class_name] = class_idx
+
+        # Add class idx to mapping.
+        for data_type in self.mapping:
+            for key in self.mapping[data_type]:
+                key['class_idx'] = \
+                    self.class_name_to_idx[data_type][key['class_name']]
+
+        # Create a mapping from index to lmdb key for each class.
+        idx_to_key_class = {}
+        for data_type in self.mapping:
+            idx_to_key_class[data_type] = {}
+            for class_idx, class_name in enumerate(class_names[data_type]):
+                idx_to_key_class[data_type][class_idx] = []
+            for key in self.mapping[data_type]:
+                idx_to_key_class[data_type][key['class_idx']].append(key)
+        self.mapping_class = idx_to_key_class
+
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            keys (dict): Each key of this dict is a data type.
+              - lmdb_key (dict):
+                  - lmdb_idx (int): Chosen LMDB dataset root.
+                  - sequence_name (str): Chosen sequence in chosen dataset.
+                  - filename (str): Chosen filename in chosen sequence.
+        """
+
+        keys = {}
+        if self.is_inference:  # evaluation mode
+            lmdb_keys = self.mapping['images']
+            keys['images'] = lmdb_keys[index % len(lmdb_keys)]
+        else:
+            lmdb_keys = self.mapping['images']
+            keys['images'] = random.choice(lmdb_keys)
+        return keys
+
+    def __getitem__(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys_per_data_type = self._sample_keys(index)
+
+        # Get class idx into a list.
+        class_idxs = []
+        for data_type in keys_per_data_type:
+            class_idxs.append(keys_per_data_type[data_type]['class_idx'])
+
+        # Get keys and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            # Unpack keys.
+            lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+            sequence_name = keys_per_data_type[data_type]['sequence_name']
+            filename = keys_per_data_type[data_type]['filename']
+            keys[data_type] = '%s/%s' % (sequence_name, filename)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+        data = self.apply_ops(data, self.full_data_post_aug_ops,
+                              full_data=True)
+
+        # Trim each image to the configured number of channels, then convert
+        # images to tensor.
+        for data_type in self.image_data_types:
+            for idx in range(len(data[data_type])):
+                data[data_type][idx] = \
+                    data[data_type][idx][:, :, :self.num_channels[data_type]]
+        data = self.to_tensor(data)
+
+        # Remove any extra dimensions.
+        for data_type in self.image_data_types:
+            data[data_type] = data[data_type][0]
+
+        # Package output.
+        data['is_flipped'] = is_flipped
+        data['key'] = keys_per_data_type
+        data['labels'] = class_idxs[0]
+        return data
diff --git a/imaginaire/datasets/lmdb.py b/imaginaire/datasets/lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..136642c1e624b886b05aaffe46f14694b0eaa29a
--- /dev/null
+++ b/imaginaire/datasets/lmdb.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import json
+import os
+
+import cv2
+import lmdb
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+from imaginaire.utils.distributed import master_only_print as print
+import imageio
+
+
+class LMDBDataset(data.Dataset):
+    r"""This deals with opening, and reading from an LMDB dataset.
+    Args:
+        root (str): Path to the LMDB file.
+    """
+
+    def __init__(self, root):
+        self.root = os.path.expanduser(root)
+        self.env = lmdb.open(self.root, max_readers=126, readonly=True,
+                             lock=False, readahead=False, meminit=False)
+        with self.env.begin(write=False) as txn:
+            self.length = txn.stat()['entries']
+
+        # Read metadata.
+        with open(os.path.join(self.root, '..', 'metadata.json')) as fin:
+            self.extensions = json.load(fin)
+
+        print('LMDB file at %s opened.' % (root))
+
+    def getitem_by_path(self, path, data_type):
+        r"""Load data item stored for key = path.
+
+        Args:
+            path (str): Key into LMDB dataset.
+            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+        Returns:
+            img (PIL.Image) or buf (str): Contents of LMDB value for this key.
+        """
+        # Figure out decoding params.
+        ext = self.extensions[data_type]
+        is_image = False
+        is_hdr = False
+        if ext in IMG_EXTENSIONS:
+            is_image = True
+            if 'tif' in ext:
+                dtype, mode = np.uint16, -1
+            elif 'JPEG' in ext or 'JPG' in ext \
+                    or 'jpeg' in ext or 'jpg' in ext:
+                dtype, mode = np.uint8, 3
+            else:
+                dtype, mode = np.uint8, -1
+        elif ext in HDR_IMG_EXTENSIONS:
+            is_hdr = True
+        else:
+            is_image = False
+
+        # Get value from key.
+        with self.env.begin(write=False) as txn:
+            buf = txn.get(path)
+
+        # Decode and return.
+        if is_image:
+            try:
+                img = cv2.imdecode(np.frombuffer(buf, dtype=dtype), mode)
+            except Exception:
+                print(path)
+            # BGR to RGB if 3 channels.
+            if img.ndim == 3 and img.shape[-1] == 3:
+                img = img[:, :, ::-1]
+            img = Image.fromarray(img)
+            return img
+        elif is_hdr:
+            try:
+                imageio.plugins.freeimage.download()
+                img = imageio.imread(buf)
+            except Exception:
+                print(path)
+            return img  # Return a numpy array
+        else:
+            return buf
+
+    def __len__(self):
+        r"""Return number of keys in LMDB dataset."""
+        return self.length
diff --git a/imaginaire/datasets/object_store.py b/imaginaire/datasets/object_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dd4f2d765def17ed52d5a442cc3bb87d31b9ded
--- /dev/null
+++ b/imaginaire/datasets/object_store.py
@@ -0,0 +1,142 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import io
+import json
+
+# import cv2
+import boto3
+from botocore.config import Config
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import imageio
+from botocore.exceptions import ClientError
+
+from imaginaire.datasets.cache import Cache
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+
+Image.MAX_IMAGE_PIXELS = None
+
+
+class ObjectStoreDataset(data.Dataset):
+    r"""This deals with opening, and reading from an AWS S3 bucket.
+    Args:
+
+        root (str): Path to the AWS S3 bucket.
+        aws_credentials_file (str): Path to file containing AWS credentials.
+        data_type (str): Which data type should this dataset load?
+    """
+
+    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
+        # Cache.
+        self.cache = False
+        if cache is not None:
+            self.cache = Cache(cache.root, cache.size_GB)
+
+        # Get bucket info, and keys to info about dataset.
+        with open(aws_credentials_file) as fin:
+            self.credentials = json.load(fin)
+
+        parts = root.split('/')
+        self.bucket = parts[0]
+        self.all_filenames_key = '/'.join(parts[1:]) + '/all_filenames.json'
+        self.metadata_key = '/'.join(parts[1:]) + '/metadata.json'
+
+        # Get list of filenames.
+        filename_info = self._get_object(self.all_filenames_key)
+        self.sequence_list = json.loads(filename_info.decode('utf-8'))
+
+        # Get length.
+        length = 0
+        for _, value in self.sequence_list.items():
+            length += len(value)
+        self.length = length
+
+        # Read metadata.
+        metadata_info = self._get_object(self.metadata_key)
+        self.extensions = json.loads(metadata_info.decode('utf-8'))
+        self.data_type = data_type
+
+        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))
+
+    def _get_object(self, key):
+        r"""Download object from bucket.
+
+        Args:
+            key (str): Key inside bucket.
+        """
+        # Look up value in cache.
+        object_content = self.cache.read(key) if self.cache else False
+        if not object_content:
+            # Either no cache used or key not found in cache.
+            config = Config(connect_timeout=30,
+                            signature_version="s3",
+                            retries={"max_attempts": 999999})
+            s3 = boto3.client('s3', **self.credentials, config=config)
+            try:
+                s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
+                object_content = s3_response_object['Body'].read()
+            except Exception as e:
+                print('%s not found' % (key))
+                print(e)
+            # Save content to cache.
+            if self.cache:
+                self.cache.write(key, object_content)
+        return object_content
+
+    def getitem_by_path(self, path, data_type):
+        r"""Load data item stored for key = path.
+
+        Args:
+            path (str): Path into AWS S3 bucket, without data_type prefix.
+            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+        Returns:
+            img (PIL.Image) or buf (str): Contents of LMDB value for this key.
+        """
+        # Figure out decoding params.
+        ext = self.extensions[data_type]
+        is_image = False
+        is_hdr = False
+        parts = path.split('/')
+        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext
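+        # The data_type is inserted after the first path component, e.g.
+        # path 'a/b/c' with data_type 'images' and ext 'jpg' maps to the
+        # bucket key 'a/images/b/c.jpg'.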
+        if ext in IMG_EXTENSIONS:
+            is_image = True
+            if 'tif' in ext:
+                _, mode = np.uint16, -1
+            elif 'JPEG' in ext or 'JPG' in ext \
+                    or 'jpeg' in ext or 'jpg' in ext:
+                _, mode = np.uint8, 3
+            else:
+                _, mode = np.uint8, -1
+        elif ext in HDR_IMG_EXTENSIONS:
+            is_hdr = True
+        else:
+            is_image = False
+
+        # Get value from key.
+        buf = self._get_object(key)
+
+        # Decode and return.
+        if is_image:
+            # This is totally a hack.
+            # We should have a better way to handle grayscale images.
+            img = Image.open(io.BytesIO(buf))
+            if mode == 3:
+                img = img.convert('RGB')
+            return img
+        elif is_hdr:
+            try:
+                imageio.plugins.freeimage.download()
+                img = imageio.imread(buf)
+            except Exception:
+                print(path)
+            return img  # Return a numpy array
+        else:
+            return buf
+
+    def __len__(self):
+        r"""Return number of keys in LMDB dataset."""
+        return self.length
diff --git a/imaginaire/datasets/paired_few_shot_videos.py b/imaginaire/datasets/paired_few_shot_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b311bc36fbc05aceaf552a47d3d89b118728be
--- /dev/null
+++ b/imaginaire/datasets/paired_few_shot_videos.py
@@ -0,0 +1,308 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+import random
+import torch
+
+from imaginaire.datasets.paired_videos import Dataset as VideoDataset
+from imaginaire.model_utils.fs_vid2vid import select_object
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Dataset(VideoDataset):
+    r"""Paired video dataset for use in few-shot vid2vid.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+        sequence_length (int): What sequence of images to provide?
+        few_shot_K (int): How many images to provide for few-shot?
+    """
+
+    def __init__(self, cfg, is_inference=False, sequence_length=None,
+                 few_shot_K=None, is_test=False):
+        self.paired = True
+        # Get initial few shot K.
+        if few_shot_K is None:
+            self.few_shot_K = cfg.data.initial_few_shot_K
+        else:
+            self.few_shot_K = few_shot_K
+        # Initialize.
+        super(Dataset, self).__init__(
+            cfg, is_inference, sequence_length=sequence_length, is_test=is_test)
+
+    def set_inference_sequence_idx(self, index, k_shot_index,
+                                   k_shot_frame_index):
+        r"""Get frames from this sequence during inference.
+
+        Args:
+            index (int): Index of inference sequence.
+            k_shot_index (int): Index of sequence from which k_shot is sampled.
+            k_shot_frame_index (int): Index of frame to sample.
+        """
+        assert self.is_inference
+        assert index < len(self.mapping)
+        assert k_shot_index < len(self.mapping)
+        assert k_shot_frame_index < len(self.mapping[k_shot_index])
+
+        self.inference_sequence_idx = index
+        self.inference_k_shot_sequence_index = k_shot_index
+        self.inference_k_shot_frame_index = k_shot_frame_index
+        self.epoch_length = len(
+            self.mapping[self.inference_sequence_idx]['filenames'])
+
+    def set_sequence_length(self, sequence_length, few_shot_K=None):
+        r"""Set the length of sequence you want as output from dataloader.
+
+        Args:
+            sequence_length (int): Length of output sequences.
+            few_shot_K (int): Number of few-shot frames.
+        """
+        if few_shot_K is None:
+            few_shot_K = self.few_shot_K
+        assert isinstance(sequence_length, int)
+        assert isinstance(few_shot_K, int)
+        if (sequence_length + few_shot_K) > self.sequence_length_max:
+            error_message = \
+                'Requested sequence length (%d) ' % (sequence_length) + \
+                '+ few shot K (%d) > ' % (few_shot_K) + \
+                'max sequence length (%d). ' % (self.sequence_length_max)
+            print(error_message)
+            sequence_length = self.sequence_length_max - few_shot_K
+            print('Reduced sequence length to %s' % (sequence_length))
+        self.sequence_length = sequence_length
+        self.few_shot_K = few_shot_K
+        # Recalculate mapping as some sequences might no longer be useful.
+        self.mapping, self.epoch_length = self._create_mapping()
+        print('Epoch length:', self.epoch_length)
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict of seq_len to list of sequences.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        # Create dict mapping length to sequence.
+        length_to_key, num_selected_seq = {}, 0
+        has_additional_lists = len(self.additional_lists) > 0
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for sequence_name, filenames in sequence_list.items():
+                if len(filenames) >= (self.sequence_length + self.few_shot_K):
+                    if len(filenames) not in length_to_key:
+                        length_to_key[len(filenames)] = []
+                    if has_additional_lists:
+                        obj_indices = self.additional_lists[lmdb_idx][
+                            sequence_name]
+                    else:
+                        obj_indices = [0 for _ in range(len(filenames))]
+                    length_to_key[len(filenames)].append({
+                        'lmdb_root': self.lmdb_roots[lmdb_idx],
+                        'lmdb_idx': lmdb_idx,
+                        'sequence_name': sequence_name,
+                        'filenames': filenames,
+                        'obj_indices': obj_indices,
+                    })
+                    num_selected_seq += 1
+        self.mapping = length_to_key
+        self.epoch_length = num_selected_seq
+
+        # At inference time, we want to use all sequences,
+        # irrespective of length.
+        if self.is_inference:
+            sequence_list = []
+            for key, sequences in self.mapping.items():
+                sequence_list.extend(sequences)
+            self.mapping = sequence_list
+
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            key (dict):
+                - lmdb_idx (int): Chosen LMDB dataset root.
+                - sequence_name (str): Chosen sequence in chosen dataset.
+                - filenames (list of str): Chosen filenames in chosen sequence.
+        """
+        if self.is_inference:
+            assert index < self.epoch_length
+            chosen_sequence = self.mapping[self.inference_sequence_idx]
+            chosen_filenames = [chosen_sequence['filenames'][index]]
+            chosen_obj_indices = [chosen_sequence['obj_indices'][index]]
+            k_shot_chosen_sequence = self.mapping[
+                self.inference_k_shot_sequence_index]
+            k_shot_chosen_filenames = [k_shot_chosen_sequence['filenames'][
+                                       self.inference_k_shot_frame_index]]
+            k_shot_chosen_obj_indices = [k_shot_chosen_sequence['obj_indices'][
+                                         self.inference_k_shot_frame_index]]
+            # Prepare few shot key.
+            few_shot_key = copy.deepcopy(k_shot_chosen_sequence)
+            few_shot_key['filenames'] = k_shot_chosen_filenames
+            few_shot_key['obj_indices'] = k_shot_chosen_obj_indices
+        else:
+            # Pick a time step for temporal augmentation.
+            time_step = random.randint(1, self.augmentor.max_time_step)
+            required_sequence_length = 1 + \
+                (self.sequence_length - 1) * time_step
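+            # With stride `time_step`, `sequence_length` output frames span
+            # 1 + (sequence_length - 1) * time_step consecutive source frames.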
+
+            # If step is too large, default to step size of 1.
+            if required_sequence_length + self.few_shot_K > \
+                    self.sequence_length_max:
+                required_sequence_length = self.sequence_length
+                time_step = 1
+
+            # Find valid sequences.
+            valid_sequences = []
+            for sequence_length, sequences in self.mapping.items():
+                if sequence_length >= required_sequence_length + \
+                        self.few_shot_K:
+                    valid_sequences.extend(sequences)
+
+            # Pick a sequence.
+            chosen_sequence = random.choice(valid_sequences)
+
+            # Choose filenames.
+            max_start_idx = len(chosen_sequence['filenames']) - \
+                required_sequence_length
+            start_idx = random.randint(0, max_start_idx)
+            end_idx = start_idx + required_sequence_length
+            chosen_filenames = chosen_sequence['filenames'][
+                start_idx:end_idx:time_step]
+            chosen_obj_indices = chosen_sequence['obj_indices'][
+                start_idx:end_idx:time_step]
+
+            # Find the K few shot filenames.
+            valid_range = list(range(start_idx)) + \
+                list(range(end_idx, len(chosen_sequence['filenames'])))
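+            # Sample the K few-shot frames from outside the chosen clip so
+            # that driving frames and few-shot frames never overlap.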
+            k_shot_chosen = sorted(random.sample(valid_range, self.few_shot_K))
+            k_shot_chosen_filenames = [chosen_sequence['filenames'][idx]
+                                       for idx in k_shot_chosen]
+            k_shot_chosen_obj_indices = [chosen_sequence['obj_indices'][idx]
+                                         for idx in k_shot_chosen]
+            assert not (set(chosen_filenames) & set(k_shot_chosen_filenames))
+
+            assert len(chosen_filenames) == self.sequence_length
+            assert len(k_shot_chosen_filenames) == self.few_shot_K
+
+            # Prepare few shot key.
+            few_shot_key = copy.deepcopy(chosen_sequence)
+            few_shot_key['filenames'] = k_shot_chosen_filenames
+            few_shot_key['obj_indices'] = k_shot_chosen_obj_indices
+
+        # Prepare output key.
+        key = copy.deepcopy(chosen_sequence)
+        key['filenames'] = chosen_filenames
+        key['obj_indices'] = chosen_obj_indices
+        return key, few_shot_key
+
+    def _prepare_data(self, keys):
+        r"""Load data and perform augmentation.
+
+        Args:
+            keys (dict): Key into LMDB/folder dataset for this item.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Unpack keys.
+        lmdb_idx = keys['lmdb_idx']
+        sequence_name = keys['sequence_name']
+        filenames = keys['filenames']
+        obj_indices = keys['obj_indices']
+
+        # Get key and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            keys[data_type] = self._create_sequence_keys(
+                sequence_name, filenames)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # Select the object in data using the object indices.
+        data = select_object(data, obj_indices)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops)
+
+        # Create copy of keypoint data types before post aug.
+        kp_data = {}
+        for data_type in self.keypoint_data_types:
+            new_key = data_type + '_xy'
+            kp_data[new_key] = copy.deepcopy(data[data_type])
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+
+        data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+        # Convert images to tensor.
+        data = self.to_tensor(data)
+
+        # Pack the sequence of images.
+        for data_type in self.image_data_types:
+            for idx in range(len(data[data_type])):
+                data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+            data[data_type] = torch.cat(data[data_type], dim=0)
+
+        # Add keypoint xy to data.
+        data.update(kp_data)
+
+        data['is_flipped'] = is_flipped
+        data['key'] = keys
+
+        return data
+
+    def _getitem(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys, few_shot_keys = self._sample_keys(index)
+
+        data = self._prepare_data(keys)
+        few_shot_data = self._prepare_data(few_shot_keys)
+
+        # Add few shot data into data.
+        for key, value in few_shot_data.items():
+            data['few_shot_' + key] = value
+
+        # At inference time, share 'common_attr' computed on the first sample
+        # with all dataloader workers, then apply full data ops.
+        if self.is_inference:
+            if index == 0:
+                pass
+            elif index < self.cfg.data.num_workers:
+                data_0 = self._getitem(0)
+                if 'common_attr' in data_0:
+                    self.common_attr = data['common_attr'] = \
+                        data_0['common_attr']
+            else:
+                if hasattr(self, 'common_attr'):
+                    data['common_attr'] = self.common_attr
+
+        data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+        if self.is_inference and index == 0 and 'common_attr' in data:
+            self.common_attr = data['common_attr']
+
+        return data
diff --git a/imaginaire/datasets/paired_few_shot_videos_native.py b/imaginaire/datasets/paired_few_shot_videos_native.py
new file mode 100644
index 0000000000000000000000000000000000000000..17bd23d8b1bd7b8d007f9a9bcf9f13ae6a16bdf0
--- /dev/null
+++ b/imaginaire/datasets/paired_few_shot_videos_native.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+import tempfile
+from collections import OrderedDict
+import warnings
+import numpy as np
+import torch
+# import torchvision.io as io
+import cv2
+from PIL import Image
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+    r"""Dataset for paired few shot videos.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+    """
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        self.paired = True
+        super(Dataset, self).__init__(cfg, is_inference, is_test)
+        self.is_video_dataset = True
+        self.few_shot_K = 1
+        self.first_last_only = getattr(cfg.data, 'first_last_only', False)
+        self.sample_far_frames_more = getattr(cfg.data, 'sample_far_frames_more', False)
+
+    def get_label_lengths(self):
+        r"""Get num channels of all labels to be concated.
+
+        Returns:
+            label_lengths (OrderedDict): Dict mapping image data_type to num
+            channels.
+        """
+        label_lengths = OrderedDict()
+        for data_type in self.input_labels:
+            data_cfg = self.cfgdata
+            if hasattr(data_cfg, 'one_hot_num_classes') and \
+                    data_type in data_cfg.one_hot_num_classes:
+                label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type]
+                if getattr(data_cfg, 'use_dont_care', False):
+                    label_lengths[data_type] += 1
+            else:
+                label_lengths[data_type] = self.num_channels[data_type]
+        return label_lengths
+
+    def num_inference_sequences(self):
+        r"""Number of sequences available for inference.
+
+        Returns:
+           (int)
+        """
+        assert self.is_inference
+        return len(self.mapping)
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict of seq_len to list of sequences.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        # Create dict mapping length to sequence.
+        mapping = []
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for sequence_name, filenames in sequence_list.items():
+                for filename in filenames:
+                    # This file is corrupt.
+                    if filename == 'z-KziTO_5so_0019_start0_end85_h596_w596':
+                        continue
+                    mapping.append({
+                        'lmdb_root': self.lmdb_roots[lmdb_idx],
+                        'lmdb_idx': lmdb_idx,
+                        'sequence_name': sequence_name,
+                        'filenames': [filename],
+                    })
+        self.mapping = mapping
+        self.epoch_length = len(mapping)
+
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            key (dict):
+                - lmdb_idx (int): Chosen LMDB dataset root.
+                - sequence_name (str): Chosen sequence in chosen dataset.
+                - filenames (list of str): Chosen filenames in chosen sequence.
+        """
+        if self.is_inference:
+            assert index < self.epoch_length
+            raise NotImplementedError
+        else:
+            # Select a video at random.
+            key = random.choice(self.mapping)
+        return key
+
+    def _create_sequence_keys(self, sequence_name, filenames):
+        r"""Create the LMDB key for this piece of information.
+
+        Args:
+            sequence_name (str): Which sequence from the chosen dataset.
+            filenames (list of str): List of filenames in this sequence.
+        Returns:
+            keys (list): List of full keys.
+        """
+        assert isinstance(filenames, list), 'Filenames should be a list.'
+        keys = []
+        for filename in filenames:
+            keys.append('%s/%s' % (sequence_name, filename))
+        return keys
+
+    def _getitem(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys = self._sample_keys(index)
+
+        # Unpack keys.
+        lmdb_idx = keys['lmdb_idx']
+        sequence_name = keys['sequence_name']
+        filenames = keys['filenames']
+
+        # Get key and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            keys[data_type] = self._create_sequence_keys(
+                sequence_name, filenames)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Get frames from video.
+        try:
+            temp = tempfile.NamedTemporaryFile()
+            temp.write(data['videos'][0])
+            temp.seek(0)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                # frames, _, info = io.read_video(temp)
+                # num_frames = frames.size(0)
+                cap = cv2.VideoCapture(temp.name)
+                num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            if self.first_last_only:
+                chosen_idxs = [0, num_frames - 1]
+            else:
+                # chosen_idxs = random.sample(range(frames.size(0)), 2)
+
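+                # Pick one target frame at random, then sample K reference
+                # frames from the remaining frames (optionally weighted so
+                # that frames farther from the target are chosen more often).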
+                chosen_idx = random.sample(range(num_frames), 1)[0]
+                few_shot_choose_range = list(range(chosen_idx)) + list(range(chosen_idx + 1, num_frames))
+                if self.sample_far_frames_more:
+                    choose_weight = list(reversed(range(chosen_idx))) + list(range(num_frames - chosen_idx - 1))
+                    few_shot_idx = random.choices(few_shot_choose_range, choose_weight, k=self.few_shot_K)
+                else:
+                    few_shot_idx = random.sample(few_shot_choose_range, k=self.few_shot_K)
+                chosen_idxs = few_shot_idx + [chosen_idx]
+
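+            # Decode the selected frames with OpenCV and convert them from
+            # BGR to RGB before wrapping them in PIL images.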
+            chosen_images = []
+            for idx in chosen_idxs:
+                # chosen_images.append(Image.fromarray(frames[idx].numpy()))
+                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+                _, frame = cap.read()
+                chosen_images.append(Image.fromarray(frame[:, :, ::-1]))
+        except Exception:
+            print('Issue with file:', sequence_name, filenames)
+            blank = np.zeros((512, 512, 3), dtype=np.uint8)
+            # Fall back to one blank frame per expected output frame.
+            chosen_images = [Image.fromarray(blank)
+                             for _ in range(self.few_shot_K + 1)]
+
+        data['videos'] = chosen_images
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(
+            data, paired=True, augment_ops=self.augmentor.augment_ops)
+        # Individual video frame augmentation is used in face-vid2vid.
+        data = self.perform_individual_video_frame(
+            data, self.augmentor.individual_video_frame_augmentation_ops)
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+
+        # Convert images to tensor.
+        data = self.to_tensor(data)
+
+        # Pack the sequence of images.
+        for data_type in self.image_data_types:
+            for idx in range(len(data[data_type])):
+                data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+            data[data_type] = torch.cat(data[data_type], dim=0)
+
+        if not self.is_video_dataset:
+            # Remove any extra dimensions.
+            for data_type in self.image_data_types:
+                if data_type in data:
+                    data[data_type] = data[data_type].squeeze(0)
+
+        # Prepare output.
+        data['driving_images'] = data['videos'][self.few_shot_K:]
+        data['source_images'] = data['videos'][:self.few_shot_K]
+        data.pop('videos')
+        data['is_flipped'] = is_flipped
+        data['key'] = keys
+        data['original_h_w'] = torch.IntTensor([
+            self.augmentor.original_h, self.augmentor.original_w])
+
+        # Apply full data ops.
+        data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+        return data
+
+    def __getitem__(self, index):
+        return self._getitem(index)
diff --git a/imaginaire/datasets/paired_images.py b/imaginaire/datasets/paired_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3112a68de57e7c85d031ee690b62ac1fbc4f96d
--- /dev/null
+++ b/imaginaire/datasets/paired_images.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+from imaginaire.datasets.paired_videos import Dataset as VideoDataset
+
+
+class Dataset(VideoDataset):
+    r"""Paired image dataset for use in pix2pixHD, SPADE.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+    """
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        self.paired = True
+        super(Dataset, self).__init__(cfg, is_inference,
+                                      sequence_length=1,
+                                      is_test=is_test)
+        self.is_video_dataset = False
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (list): List mapping idx to key.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        idx_to_key = []
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for sequence_name, filenames in sequence_list.items():
+                for filename in filenames:
+                    idx_to_key.append({
+                        'lmdb_root': self.lmdb_roots[lmdb_idx],
+                        'lmdb_idx': lmdb_idx,
+                        'sequence_name': sequence_name,
+                        'filenames': [filename],
+                    })
+        self.mapping = idx_to_key
+        self.epoch_length = len(self.mapping)
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            key (dict):
+              - lmdb_idx (int): Chosen LMDB dataset root.
+              - sequence_name (str): Chosen sequence in chosen dataset.
+              - filenames (list of str): Chosen filenames in chosen sequence.
+        """
+        assert self.sequence_length == 1, \
+            'Image dataset can only have sequence length = 1, not %d' % (
+                self.sequence_length)
+        return self.mapping[index]
+
+    def set_sequence_length(self, sequence_length):
+        r"""Set the length of sequence you want as output from dataloader.
+        Ignore this as this is an image loader.
+
+        Args:
+            sequence_length (int): Length of output sequences.
+        """
+        pass
+
+    def set_inference_sequence_idx(self, index):
+        r"""Get frames from this sequence during inference.
+        Overridden from super as this is not applicable for images.
+
+        Args:
+            index (int): Index of inference sequence.
+        """
+        raise RuntimeError('Image dataset does not have sequences.')
+
+    def num_inference_sequences(self):
+        r"""Number of sequences available for inference.
+        Overridden from super as this is not applicable for images.
+
+        Returns:
+            (int)
+        """
+        raise RuntimeError('Image dataset does not have sequences.')
diff --git a/imaginaire/datasets/paired_videos.py b/imaginaire/datasets/paired_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e5c645f9de00184fa6cf7cc0b6b910c5815417
--- /dev/null
+++ b/imaginaire/datasets/paired_videos.py
@@ -0,0 +1,288 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+import random
+from collections import OrderedDict
+
+import torch
+
+from imaginaire.datasets.base import BaseDataset
+from imaginaire.model_utils.fs_vid2vid import select_object
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Dataset(BaseDataset):
+    r"""Paired video dataset for use in vid2vid, wc_vid2vid.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+        sequence_length (int): What sequence of images to provide?
+    """
+
+    def __init__(self, cfg,
+                 is_inference=False,
+                 sequence_length=None,
+                 is_test=False):
+        self.paired = True
+        # Get initial sequence length.
+        if sequence_length is None and not is_inference:
+            self.sequence_length = cfg.data.train.initial_sequence_length
+        elif sequence_length is None and is_inference:
+            self.sequence_length = 2
+        else:
+            self.sequence_length = sequence_length
+        super(Dataset, self).__init__(cfg, is_inference, is_test)
+        self.set_sequence_length(self.sequence_length)
+        self.is_video_dataset = True
+
+    def get_label_lengths(self):
+        r"""Get num channels of all labels to be concated.
+
+        Returns:
+            label_lengths (OrderedDict): Dict mapping image data_type to num
+            channels.
+        """
+        label_lengths = OrderedDict()
+        for data_type in self.input_labels:
+            data_cfg = self.cfgdata
+            if hasattr(data_cfg, 'one_hot_num_classes') and data_type in data_cfg.one_hot_num_classes:
+                label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type]
+                if getattr(data_cfg, 'use_dont_care', False):
+                    label_lengths[data_type] += 1
+            else:
+                label_lengths[data_type] = self.num_channels[data_type]
+        return label_lengths
+
+    def num_inference_sequences(self):
+        r"""Number of sequences available for inference.
+
+        Returns:
+            (int)
+        """
+        assert self.is_inference
+        return len(self.mapping)
+
+    def set_inference_sequence_idx(self, index):
+        r"""Get frames from this sequence during inference.
+
+        Args:
+            index (int): Index of inference sequence.
+        """
+        assert self.is_inference
+        assert index < len(self.mapping)
+        self.inference_sequence_idx = index
+        self.epoch_length = len(
+            self.mapping[self.inference_sequence_idx]['filenames'])
+
+    def set_sequence_length(self, sequence_length):
+        r"""Set the length of sequence you want as output from dataloader.
+
+        Args:
+            sequence_length (int): Length of output sequences.
+        """
+        assert isinstance(sequence_length, int)
+        if sequence_length > self.sequence_length_max:
+            print('Requested sequence length (%d) > ' % (sequence_length) +
+                  'max sequence length (%d). ' % (self.sequence_length_max) +
+                  'Limiting sequence length to max sequence length.')
+            sequence_length = self.sequence_length_max
+        self.sequence_length = sequence_length
+        # Recalculate mapping as some sequences might no longer be useful.
+        self.mapping, self.epoch_length = self._create_mapping()
+        print('Epoch length:', self.epoch_length)
+
+    def _compute_dataset_stats(self):
+        r"""Compute statistics of video sequence dataset.
+
+        Returns:
+            sequence_length_max (int): Maximum sequence length.
+        """
+        print('Num datasets:', len(self.sequence_lists))
+
+        if self.sequence_length >= 1:
+            num_sequences, sequence_length_max = 0, 0
+            for sequence in self.sequence_lists:
+                for _, filenames in sequence.items():
+                    sequence_length_max = max(
+                        sequence_length_max, len(filenames))
+                    num_sequences += 1
+            print('Num sequences:', num_sequences)
+            print('Max sequence length:', sequence_length_max)
+            self.sequence_length_max = sequence_length_max
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict of seq_len to list of sequences
+                (flattened to a single list at inference time).
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        # Create dict mapping length to sequence.
+        length_to_key, num_selected_seq = {}, 0
+        total_num_of_frames = 0
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for sequence_name, filenames in sequence_list.items():
+                if len(filenames) >= self.sequence_length:
+                    total_num_of_frames += len(filenames)
+                    if len(filenames) not in length_to_key:
+                        length_to_key[len(filenames)] = []
+                    length_to_key[len(filenames)].append({
+                        'lmdb_root': self.lmdb_roots[lmdb_idx],
+                        'lmdb_idx': lmdb_idx,
+                        'sequence_name': sequence_name,
+                        'filenames': filenames,
+                    })
+                    num_selected_seq += 1
+        self.mapping = length_to_key
+        self.epoch_length = num_selected_seq
+        if not self.is_inference and self.epoch_length < \
+                self.cfgdata.train.batch_size * 8:
+            self.epoch_length = total_num_of_frames
+
+        # At inference time, we want to use all sequences,
+        # irrespective of length.
+        if self.is_inference:
+            sequence_list = []
+            for key, sequences in self.mapping.items():
+                sequence_list.extend(sequences)
+            self.mapping = sequence_list
+
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            key (dict):
+              - lmdb_idx (int): Chosen LMDB dataset root.
+              - sequence_name (str): Chosen sequence in chosen dataset.
+              - filenames (list of str): Chosen filenames in chosen sequence.
+        """
+        if self.is_inference:
+            assert index < self.epoch_length
+            chosen_sequence = self.mapping[self.inference_sequence_idx]
+            chosen_filenames = [chosen_sequence['filenames'][index]]
+        else:
+            # Pick a time step for temporal augmentation.
+            time_step = random.randint(1, self.augmentor.max_time_step)
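+            # With a stride of time_step, we need
+            # 1 + (sequence_length - 1) * time_step consecutive frames.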
+            required_sequence_length = 1 + \
+                (self.sequence_length - 1) * time_step
+
+            # If step is too large, default to step size of 1.
+            if required_sequence_length > self.sequence_length_max:
+                required_sequence_length = self.sequence_length
+                time_step = 1
+
+            # Find valid sequences.
+            valid_sequences = []
+            for sequence_length, sequences in self.mapping.items():
+                if sequence_length >= required_sequence_length:
+                    valid_sequences.extend(sequences)
+
+            # Pick a sequence.
+            chosen_sequence = random.choice(valid_sequences)
+
+            # Choose filenames.
+            max_start_idx = len(chosen_sequence['filenames']) - \
+                required_sequence_length
+            start_idx = random.randint(0, max_start_idx)
+
+            chosen_filenames = chosen_sequence['filenames'][
+                start_idx:start_idx + required_sequence_length:time_step]
+            assert len(chosen_filenames) == self.sequence_length
+
+        # Prepare output key.
+        key = copy.deepcopy(chosen_sequence)
+        key['filenames'] = chosen_filenames
+        return key
+
+    def _create_sequence_keys(self, sequence_name, filenames):
+        r"""Create the LMDB key for this piece of information.
+
+        Args:
+            sequence_name (str): Which sequence from the chosen dataset.
+            filenames (list of str): List of filenames in this sequence.
+        Returns:
+            keys (list): List of full keys.
+        """
+        assert isinstance(filenames, list), 'Filenames should be a list.'
+        keys = []
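+        # Sequence names may carry a trailing 9-character '___xxx___' marker;
+        # strip it before building the LMDB keys.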
+        if sequence_name.endswith('___') and sequence_name[-9:-6] == '___':
+            sequence_name = sequence_name[:-9]
+        for filename in filenames:
+            keys.append('%s/%s' % (sequence_name, filename))
+        return keys
+
+    def _getitem(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys = self._sample_keys(index)
+
+        # Unpack keys.
+        lmdb_idx = keys['lmdb_idx']
+        sequence_name = keys['sequence_name']
+        filenames = keys['filenames']
+
+        # Get key and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            keys[data_type] = self._create_sequence_keys(
+                sequence_name, filenames)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # If multiple subjects exist in the data, only pick one to synthesize.
+        data = select_object(data, obj_indices=None)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops)
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+        data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+        # Convert images to tensor.
+        data = self.to_tensor(data)
+
+        # Pack the sequence of images.
+        for data_type in self.image_data_types + self.hdr_image_data_types:
+            for idx in range(len(data[data_type])):
+                data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+            data[data_type] = torch.cat(data[data_type], dim=0)
+
+        if not self.is_video_dataset:
+            # Remove any extra dimensions.
+            for data_type in self.data_types:
+                if data_type in data:
+                    data[data_type] = data[data_type].squeeze(0)
+
+        data['is_flipped'] = is_flipped
+        data['key'] = keys
+        data['original_h_w'] = torch.IntTensor([
+            self.augmentor.original_h, self.augmentor.original_w])
+
+        # Apply full data ops.
+        data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+        return data
+
+    def __getitem__(self, index):
+        return self._getitem(index)
diff --git a/imaginaire/datasets/unpaired_few_shot_images.py b/imaginaire/datasets/unpaired_few_shot_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8fa1e5ec2cd528c72effd065338fa693d1319c8
--- /dev/null
+++ b/imaginaire/datasets/unpaired_few_shot_images.py
@@ -0,0 +1,182 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+    r"""Image dataset for use in FUNIT.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+    """
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        self.paired = False
+        super(Dataset, self).__init__(cfg, is_inference, is_test)
+        self.num_content_classes = len(self.class_name_to_idx['images_content'])
+        self.num_style_classes = len(self.class_name_to_idx['images_style'])
+        self.sample_class_idx = None
+        self.content_offset = 8888
+        self.content_interval = 100
+
+    def set_sample_class_idx(self, class_idx=None):
+        r"""Set sample class idx.
+
+        Args:
+            class_idx (int): Which class idx to sample from.
+        """
+        self.sample_class_idx = class_idx
+        if class_idx is None:
+            self.epoch_length = \
+                max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()])
+        else:
+            self.epoch_length = \
+                len(self.mapping_class['images_style'][class_idx])
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict mapping each data type to its list
+                of LMDB keys.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        idx_to_key, class_names = {}, {}
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for data_type, data_type_sequence_list in sequence_list.items():
+                class_names[data_type] = []
+                if data_type not in idx_to_key:
+                    idx_to_key[data_type] = []
+                for sequence_name, filenames in data_type_sequence_list.items():
+                    class_name = sequence_name.split('/')[0]
+                    for filename in filenames:
+                        idx_to_key[data_type].append({
+                            'lmdb_root': self.lmdb_roots[lmdb_idx],
+                            'lmdb_idx': lmdb_idx,
+                            'sequence_name': sequence_name,
+                            'filename': filename,
+                            'class_name': class_name
+                        })
+                    class_names[data_type].append(class_name)
+        self.mapping = idx_to_key
+        self.epoch_length = max([len(lmdb_keys)
+                                 for _, lmdb_keys in self.mapping.items()])
+
+        # Create mapping from class name to class idx.
+        self.class_name_to_idx = {}
+        for data_type, class_names_data_type in class_names.items():
+            self.class_name_to_idx[data_type] = {}
+            class_names_data_type = sorted(list(set(class_names_data_type)))
+            for class_idx, class_name in enumerate(class_names_data_type):
+                self.class_name_to_idx[data_type][class_name] = class_idx
+
+        # Add class idx to mapping.
+        for data_type in self.mapping:
+            for key in self.mapping[data_type]:
+                key['class_idx'] = \
+                    self.class_name_to_idx[data_type][key['class_name']]
+
+        # Create a mapping from index to lmdb key for each class.
+        idx_to_key_class = {}
+        for data_type in self.mapping:
+            idx_to_key_class[data_type] = {}
+            for class_idx, class_name in enumerate(class_names[data_type]):
+                idx_to_key_class[data_type][class_idx] = []
+            for key in self.mapping[data_type]:
+                idx_to_key_class[data_type][key['class_idx']].append(key)
+        self.mapping_class = idx_to_key_class
+
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            keys (dict): Each key of this dict is a data type, mapping to an
+            lmdb_key (dict) with:
+              - lmdb_idx (int): Chosen LMDB dataset root.
+              - sequence_name (str): Chosen sequence in chosen dataset.
+              - filename (str): Chosen filename in chosen sequence.
+        """
+
+        keys = {}
+        if self.is_inference:  # evaluation mode
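+            # Deterministically pair each style image of the sampled class
+            # with a content image by stepping through the content list with
+            # a class-dependent offset and a fixed stride.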
+            lmdb_keys_content = self.mapping['images_content']
+            keys['images_content'] = \
+                lmdb_keys_content[
+                    ((index + self.content_offset * self.sample_class_idx) *
+                     self.content_interval) % len(lmdb_keys_content)]
+
+            lmdb_keys_style = \
+                self.mapping_class['images_style'][self.sample_class_idx]
+            keys['images_style'] = lmdb_keys_style[index]
+        else:
+            lmdb_keys_content = self.mapping['images_content']
+            lmdb_keys_style = self.mapping['images_style']
+            keys['images_content'] = random.choice(lmdb_keys_content)
+            keys['images_style'] = random.choice(lmdb_keys_style)
+        return keys
+
+    def __getitem__(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys_per_data_type = self._sample_keys(index)
+
+        # Get class idx into a list.
+        class_idxs = []
+        for data_type in keys_per_data_type:
+            class_idxs.append(keys_per_data_type[data_type]['class_idx'])
+
+        # Get keys and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            # Unpack keys.
+            lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+            sequence_name = keys_per_data_type[data_type]['sequence_name']
+            filename = keys_per_data_type[data_type]['filename']
+            keys[data_type] = '%s/%s' % (sequence_name, filename)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+        data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+        # Convert images to tensor.
+        data = self.to_tensor(data)
+
+        # Remove any extra dimensions.
+        for data_type in self.image_data_types:
+            data[data_type] = data[data_type][0]
+
+        # Package output.
+        data['is_flipped'] = is_flipped
+        data['key'] = keys_per_data_type
+        data['labels_content'] = class_idxs[0]
+        data['labels_style'] = class_idxs[1]
+
+        return data
diff --git a/imaginaire/datasets/unpaired_images.py b/imaginaire/datasets/unpaired_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a49a876705771a2d7a3f836bdbee2cdd328c10
--- /dev/null
+++ b/imaginaire/datasets/unpaired_images.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+    r"""Unpaired image dataset for use in MUNIT.
+
+    Args:
+        cfg (Config): Loaded config object.
+        is_inference (bool): In train or inference mode?
+    """
+
+    def __init__(self, cfg, is_inference=False, is_test=False):
+        self.paired = False
+        super(Dataset, self).__init__(cfg, is_inference, is_test)
+
+    def _create_mapping(self):
+        r"""Creates mapping from idx to key in LMDB.
+
+        Returns:
+            (tuple):
+              - self.mapping (dict): Dict mapping each data type to its list
+                of LMDB keys.
+              - self.epoch_length (int): Number of samples in an epoch.
+        """
+        idx_to_key = {}
+        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+            for data_type, data_type_sequence_list in sequence_list.items():
+                if data_type not in idx_to_key:
+                    idx_to_key[data_type] = []
+                for sequence_name, filenames in data_type_sequence_list.items():
+                    for filename in filenames:
+                        idx_to_key[data_type].append({
+                            'lmdb_root': self.lmdb_roots[lmdb_idx],
+                            'lmdb_idx': lmdb_idx,
+                            'sequence_name': sequence_name,
+                            'filename': filename,
+                        })
+        self.mapping = idx_to_key
+        self.epoch_length = max([len(lmdb_keys)
+                                 for _, lmdb_keys in self.mapping.items()])
+        return self.mapping, self.epoch_length
+
+    def _sample_keys(self, index):
+        r"""Gets files to load for this sample.
+
+        Args:
+            index (int): Index in [0, len(dataset)].
+        Returns:
+            keys (dict): Each key of this dict is a data type.
+                lmdb_key (dict):
+                    lmdb_idx (int): Chosen LMDB dataset root.
+                    sequence_name (str): Chosen sequence in chosen dataset.
+                    filename (str): Chosen filename in chosen sequence.
+        """
+        keys = {}
+        for data_type in self.data_types:
+            lmdb_keys = self.mapping[data_type]
+            if self.is_inference:
+                # Modulo ensures valid indexing in case A and B have different
+                # number of files.
+                keys[data_type] = lmdb_keys[index % len(lmdb_keys)]
+            else:
+                keys[data_type] = random.choice(lmdb_keys)
+        return keys
+
+    def __getitem__(self, index):
+        r"""Gets selected files.
+
+        Args:
+            index (int): Index into dataset.
+        Returns:
+            data (dict): Dict with all chosen data_types.
+        """
+        # Select a sample from the available data.
+        keys_per_data_type = self._sample_keys(index)
+
+        # Get keys and lmdbs.
+        keys, lmdbs = {}, {}
+        for data_type in self.dataset_data_types:
+            # Unpack keys.
+            lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+            sequence_name = keys_per_data_type[data_type]['sequence_name']
+            filename = keys_per_data_type[data_type]['filename']
+            keys[data_type] = '%s/%s' % (sequence_name, filename)
+            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+        # Load all data for this index.
+        data = self.load_from_dataset(keys, lmdbs)
+
+        # Apply ops pre augmentation.
+        data = self.apply_ops(data, self.pre_aug_ops)
+
+        # Do augmentations for images.
+        data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+        # Apply ops post augmentation.
+        data = self.apply_ops(data, self.post_aug_ops)
+        data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+        # Convert images to tensor.
+        data = self.to_tensor(data)
+
+        # Remove any extra dimensions.
+        for data_type in self.image_data_types:
+            data[data_type] = data[data_type][0]
+
+        # Package output.
+        data['is_flipped'] = is_flipped
+        data['key'] = keys_per_data_type
+
+        return data
diff --git a/imaginaire/discriminators/__init__.py b/imaginaire/discriminators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/discriminators/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/discriminators/dummy.py b/imaginaire/discriminators/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a345806f6f844fbcf0c9da1915f4db5b2fa3d587
--- /dev/null
+++ b/imaginaire/discriminators/dummy.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Discriminator(nn.Module):
+    """Dummy Discriminator constructor.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super(Discriminator, self).__init__()
+        self.dummy_layer = LinearBlock(1, 1)
+
+    def forward(self, data):
+        """Dummy discriminator forward.
+
+        Args:
+            data (dict):
+        """
+        return
diff --git a/imaginaire/discriminators/fpse.py b/imaginaire/discriminators/fpse.py
new file mode 100644
index 0000000000000000000000000000000000000000..231b666bfd3970c760df7ce5d1174193ad9a7708
--- /dev/null
+++ b/imaginaire/discriminators/fpse.py
@@ -0,0 +1,132 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import Conv2dBlock
+
+
+class FPSEDiscriminator(nn.Module):
+    r"""# Feature-Pyramid Semantics Embedding Discriminator. This is a copy
+    of the discriminator in https://arxiv.org/pdf/1910.06809.pdf
+    """
+
+    def __init__(self,
+                 num_input_channels,
+                 num_labels,
+                 num_filters,
+                 kernel_size,
+                 weight_norm_type,
+                 activation_norm_type):
+        super().__init__()
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        nonlinearity = 'leakyrelu'
+        stride1_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=1,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        down_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=2,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        latent_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=1,
+                              stride=1,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        # bottom-up pathway
+
+        self.enc1 = down_conv2d_block(num_input_channels, num_filters)
+        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
+        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
+        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
+        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)
+
+        # top-down pathway
+        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
+        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
+        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+
+        # upsampling
+        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
+                                      align_corners=False)
+
+        # final layers
+        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+        self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+        self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+
+        # true/false prediction and semantic alignment prediction
+        self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
+        self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1)
+        self.embedding = Conv2dBlock(num_labels, num_filters * 2, kernel_size=1)
+
+    def forward(self, images, segmaps):
+        r"""
+
+        Args:
+            images: image tensors.
+            segmaps: segmentation map tensors.
+        """
+        # bottom-up pathway
+        feat11 = self.enc1(images)
+        feat12 = self.enc2(feat11)
+        feat13 = self.enc3(feat12)
+        feat14 = self.enc4(feat13)
+        feat15 = self.enc5(feat14)
+        # top-down pathway and lateral connections
+        feat25 = self.lat5(feat15)
+        feat24 = self.upsample2x(feat25) + self.lat4(feat14)
+        feat23 = self.upsample2x(feat24) + self.lat3(feat13)
+        feat22 = self.upsample2x(feat23) + self.lat2(feat12)
+        # final prediction layers
+        feat32 = self.final2(feat22)
+        feat33 = self.final3(feat23)
+        feat34 = self.final4(feat24)
+        # Patch-based True/False prediction
+        pred2 = self.output(feat32)
+        pred3 = self.output(feat33)
+        pred4 = self.output(feat34)
+        seg2 = self.seg(feat32)
+        seg3 = self.seg(feat33)
+        seg4 = self.seg(feat34)
+
+        # Segmentation map embedding.
+        segembs = self.embedding(segmaps)
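+        # Pool the embedded segmentation map down to the spatial resolution
+        # of each prediction scale.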
+        segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2)
+        segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2)
+        segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2)
+        segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2)
+
+        # semantics embedding discriminator score
+        pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True)
+        pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True)
+        pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True)
+
+        # Return predictions from all three resolutions.
+
+        return pred2, pred3, pred4
diff --git a/imaginaire/discriminators/fs_vid2vid.py b/imaginaire/discriminators/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e29f64a301864f03238db7bd1ad444bee9773e
--- /dev/null
+++ b/imaginaire/discriminators/fs_vid2vid.py
@@ -0,0 +1,318 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import importlib
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
+from imaginaire.model_utils.fs_vid2vid import get_fg_mask, pick_image
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.misc import get_nested_attr
+
+
+class Discriminator(nn.Module):
+    r"""Image and video discriminator constructor.
+
+    Args:
+        dis_cfg (obj): Discriminator part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super().__init__()
+        self.data_cfg = data_cfg
+        num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        if num_input_channels == 0:
+            num_input_channels = getattr(data_cfg, 'label_channels', 1)
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        self.num_frames_D = data_cfg.num_frames_D
+        self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)
+        num_netD_input_channels = (num_input_channels + num_img_channels)
+        self.use_few_shot = 'few_shot' in data_cfg.type
+        if self.use_few_shot:
+            num_netD_input_channels *= 2
+        self.net_D = MultiPatchDiscriminator(dis_cfg.image,
+                                             num_netD_input_channels)
+
+        self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
+        if self.add_dis_cfg is not None:
+            for name in self.add_dis_cfg:
+                add_dis_cfg = self.add_dis_cfg[name]
+                num_ch = num_img_channels * (2 if self.use_few_shot else 1)
+                setattr(self, 'net_D_' + name,
+                        MultiPatchDiscriminator(add_dis_cfg, num_ch))
+
+        # Temporal discriminator.
+        self.num_netDT_input_channels = num_img_channels * self.num_frames_D
+        for n in range(self.num_scales):
+            setattr(self, 'net_DT%d' % n,
+                    MultiPatchDiscriminator(dis_cfg.temporal,
+                                            self.num_netDT_input_channels))
+        self.has_fg = getattr(data_cfg, 'has_foreground', False)
+
+    def forward(self, data, net_G_output, past_frames):
+        r"""Discriminator forward.
+
+        Args:
+            data (dict): Input data.
+            net_G_output (dict): Generator output.
+            past_frames (list of tensors): Past real frames / generator outputs.
+        Returns:
+            (tuple):
+              - output (dict): Discriminator output.
+              - past_frames (list of tensors): New past frames by adding
+                current outputs.
+        """
+        label, real_image = data['label'], data['image']
+        # Only operate on the latest output frame.
+        if label.dim() == 5:
+            label = label[:, -1]
+        if self.use_few_shot:
+            # Pick only one reference image to concat with.
+            ref_idx = net_G_output['ref_idx'] \
+                if 'ref_idx' in net_G_output else 0
+            ref_label = pick_image(data['ref_labels'], ref_idx)
+            ref_image = pick_image(data['ref_images'], ref_idx)
+            # Concat references with label map as discriminator input.
+            label = torch.cat([label, ref_label, ref_image], dim=1)
+        fake_image = net_G_output['fake_images']
+        output = dict()
+
+        # Individual frame loss.
+        pred_real, pred_fake = self.discrminate_image(self.net_D, label,
+                                                      real_image, fake_image)
+        output['indv'] = dict()
+        output['indv']['pred_real'] = pred_real
+        output['indv']['pred_fake'] = pred_fake
+
+        if 'fake_raw_images' in net_G_output and \
+                net_G_output['fake_raw_images'] is not None:
+            # Raw generator output loss.
+            fake_raw_image = net_G_output['fake_raw_images']
+            fg_mask = get_fg_mask(data['label'], self.has_fg)
+            pred_real, pred_fake = self.discrminate_image(
+                self.net_D, label,
+                real_image * fg_mask,
+                fake_raw_image * fg_mask)
+            output['raw'] = dict()
+            output['raw']['pred_real'] = pred_real
+            output['raw']['pred_fake'] = pred_fake
+
+        # Additional GAN loss on specific regions.
+        if self.add_dis_cfg is not None:
+            for name in self.add_dis_cfg:
+                # Crop corresponding regions in the image according to the
+                # crop function.
+                add_dis_cfg = self.add_dis_cfg[name]
+                file, crop_func = add_dis_cfg.crop_func.split('::')
+                file = importlib.import_module(file)
+                crop_func = getattr(file, crop_func)
+
+                real_crop = crop_func(self.data_cfg, real_image, label)
+                fake_crop = crop_func(self.data_cfg, fake_image, label)
+                if self.use_few_shot:
+                    ref_crop = crop_func(self.data_cfg, ref_image, label)
+                    if ref_crop is not None:
+                        real_crop = torch.cat([real_crop, ref_crop], dim=1)
+                        fake_crop = torch.cat([fake_crop, ref_crop], dim=1)
+
+                # Feed the crops to specific discriminator.
+                if fake_crop is not None:
+                    net_D = getattr(self, 'net_D_' + name)
+                    pred_real, pred_fake = \
+                        self.discrminate_image(net_D, None,
+                                               real_crop, fake_crop)
+                else:
+                    pred_real = pred_fake = None
+                output[name] = dict()
+                output[name]['pred_real'] = pred_real
+                output[name]['pred_fake'] = pred_fake
+
+        # Temporal loss.
+        past_frames, skipped_frames = \
+            get_all_skipped_frames(past_frames, [real_image, fake_image],
+                                   self.num_scales, self.num_frames_D)
+
+        for scale in range(self.num_scales):
+            real_image, fake_image = \
+                [skipped_frame[scale] for skipped_frame in skipped_frames]
+            pred_real, pred_fake = self.discriminate_video(real_image,
+                                                           fake_image, scale)
+            output['temporal_%d' % scale] = dict()
+            output['temporal_%d' % scale]['pred_real'] = pred_real
+            output['temporal_%d' % scale]['pred_fake'] = pred_fake
+
+        return output, past_frames
+
+    def discrminate_image(self, net_D, real_A, real_B, fake_B):
+        r"""Discriminate individual images.
+
+        Args:
+            net_D (obj): Discriminator network.
+            real_A (NxC1xHxW tensor): Input label map.
+            real_B (NxC2xHxW tensor): Real image.
+            fake_B (NxC2xHxW tensor): Fake image.
+        Returns:
+            (tuple):
+              - pred_real (NxC3xH2xW2 tensor): Output of net_D for real images.
+              - pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake images.
+        """
+        if real_A is not None:
+            real_AB = torch.cat([real_A, real_B], dim=1)
+            fake_AB = torch.cat([real_A, fake_B], dim=1)
+        else:
+            real_AB, fake_AB = real_B, fake_B
+
+        pred_real = net_D.forward(real_AB)
+        pred_fake = net_D.forward(fake_AB)
+        return pred_real, pred_fake
+
+    def discriminate_video(self, real_B, fake_B, scale):
+        r"""Discriminate a sequence of images.
+
+        Args:
+            real_B (NxCxHxW tensor): Real image.
+            fake_B (NxCxHxW tensor): Fake image.
+            scale (int): Temporal scale.
+        Returns:
+            (tuple):
+              - pred_real (NxC2xH2xW2 tensor): Output of net_D for real images.
+              - pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake images.
+        """
+        if real_B is None:
+            return None, None
+        net_DT = getattr(self, 'net_DT%d' % scale)
+        height, width = real_B.shape[-2:]
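+        # Stack the temporal dimension into channels so the frame sequence
+        # can be scored by a 2D patch discriminator.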
+        real_B = real_B.view(-1, self.num_netDT_input_channels, height, width)
+        fake_B = fake_B.view(-1, self.num_netDT_input_channels, height, width)
+
+        pred_real = net_DT.forward(real_B)
+        pred_fake = net_DT.forward(fake_B)
+        return pred_real, pred_fake
+
+
+def get_all_skipped_frames(past_frames, new_frames, t_scales, tD):
+    r"""Get temporally skipped frames from the input frames.
+
+    Args:
+        past_frames (list of tensors): Past real frames / generator outputs.
+        new_frames (list of tensors): Current real frame / generated output.
+        t_scales (int): Temporal scale.
+        tD (int): Number of frames as input to the temporal discriminator.
+    Returns:
+        (tuple):
+          - new_past_frames (list of tensors): Past + current frames.
+          - skipped_frames (list of tensors): Temporally skipped frames using
+            the given t_scales.
+    """
+    new_past_frames, skipped_frames = [], []
+    for past_frame, new_frame in zip(past_frames, new_frames):
+        skipped_frame = None
+        if t_scales > 0:
+            past_frame, skipped_frame = \
+                get_skipped_frames(past_frame, new_frame.unsqueeze(1),
+                                   t_scales, tD)
+        new_past_frames.append(past_frame)
+        skipped_frames.append(skipped_frame)
+    return new_past_frames, skipped_frames
+
+
+def get_skipped_frames(all_frames, frame, t_scales, tD):
+    r"""Get temporally skipped frames from the input frames.
+
+    Args:
+        all_frames (NxTxCxHxW tensor): All past frames.
+        frame (Nx1xCxHxW tensor): Current frame.
+        t_scales (int): Temporal scale.
+        tD (int): Number of frames as input to the temporal discriminator.
+    Returns:
+        (tuple):
+          - all_frames (NxTxCxHxW tensor): Past + current frames.
+          - skipped_frames (list of NxTxCxHxW tensors): Temporally skipped
+            frames.
+    """
+    all_frames = torch.cat([all_frames.detach(), frame], dim=1) \
+        if all_frames is not None else frame
+    skipped_frames = [None] * t_scales
+    for s in range(t_scales):
+        # Number of skipped frames between neighboring frames (e.g. 1, 3, 9,...)
+        t_step = tD ** s
+        # Number of frames the final triplet frames span before skipping
+        # (e.g., 2, 6, 18, ...).
+        t_span = t_step * (tD-1)
+        if all_frames.size(1) > t_span:
+            skipped_frames[s] = all_frames[:, -(t_span+1)::t_step].contiguous()
+
+    # Maximum number of past frames we need to keep track of.
+    max_num_prev_frames = (tD ** (t_scales-1)) * (tD-1)
+    # Remove past frames that are older than this number.
+    if all_frames.size()[1] > max_num_prev_frames:
+        all_frames = all_frames[:, -max_num_prev_frames:]
+    return all_frames, skipped_frames
+
+
+class MultiPatchDiscriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator part of the yaml config file.
+        num_input_channels (int): Number of input channels.
+    """
+
+    def __init__(self, dis_cfg, num_input_channels):
+        super(MultiPatchDiscriminator, self).__init__()
+        kernel_size = getattr(dis_cfg, 'kernel_size', 4)
+        num_filters = getattr(dis_cfg, 'num_filters', 64)
+        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+        num_discriminators = getattr(dis_cfg, 'num_discriminators', 3)
+        num_layers = getattr(dis_cfg, 'num_layers', 3)
+        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+        weight_norm_type = getattr(dis_cfg, 'weight_norm_type',
+                                   'spectral_norm')
+        self.nets_discriminator = []
+        for i in range(num_discriminators):
+            net_discriminator = NLayerPatchDiscriminator(
+                kernel_size,
+                num_input_channels,
+                num_filters,
+                num_layers,
+                max_num_filters,
+                activation_norm_type,
+                weight_norm_type)
+            self.add_module('discriminator_%d' % i, net_discriminator)
+            self.nets_discriminator.append(net_discriminator)
+
+    def forward(self, input_x):
+        r"""Multi-resolution patch discriminator forward.
+
+        Args:
+            input_x (N x C x H x W tensor) : Concatenation of images and
+                semantic representations.
+        Returns:
+            (dict):
+              - output (list): list of output tensors produced by individual
+                patch discriminators.
+              - features (list): list of lists of features produced by
+                individual patch discriminators.
+        """
+        output_list = []
+        features_list = []
+        input_downsampled = input_x
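+        # Each patch discriminator sees the input at half the resolution of
+        # the previous one.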
+        for name, net_discriminator in self.named_children():
+            if not name.startswith('discriminator_'):
+                continue
+            output, features = net_discriminator(input_downsampled)
+            output_list.append(output)
+            features_list.append(features)
+            input_downsampled = F.interpolate(
+                input_downsampled, scale_factor=0.5, mode='bilinear',
+                align_corners=True, recompute_scale_factor=True)
+        output_x = dict()
+        output_x['output'] = output_list
+        output_x['features'] = features_list
+        return output_x
diff --git a/imaginaire/discriminators/funit.py b/imaginaire/discriminators/funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e74238bf723058dfb680989a8839ce54bf98e61
--- /dev/null
+++ b/imaginaire/discriminators/funit.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+import torch
+from torch import nn
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+
+
+class Discriminator(nn.Module):
+    r"""Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super().__init__()
+        self.model = ResDiscriminator(**vars(dis_cfg))
+
+    def forward(self, data, net_G_output, recon=True):
+        r"""Improved FUNIT discriminator forward function.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            net_G_output (dict): Fake data generated at the current iteration.
+            recon (bool): If ``True``, also classifies reconstructed images.
+        """
+        source_labels = data['labels_content']
+        target_labels = data['labels_style']
+        fake_out_trans, fake_features_trans = \
+            self.model(net_G_output['images_trans'], target_labels)
+        output = dict(fake_out_trans=fake_out_trans,
+                      fake_features_trans=fake_features_trans)
+
+        real_out_style, real_features_style = \
+            self.model(data['images_style'], target_labels)
+        output.update(dict(real_out_style=real_out_style,
+                           real_features_style=real_features_style))
+        if recon:
+            fake_out_recon, fake_features_recon = \
+                self.model(net_G_output['images_recon'], source_labels)
+            output.update(dict(fake_out_recon=fake_out_recon,
+                               fake_features_recon=fake_features_recon))
+        return output
+
+
+class ResDiscriminator(nn.Module):
+    r"""Residual discriminator architecture used in the FUNIT paper."""
+
+    def __init__(self,
+                 image_channels=3,
+                 num_classes=119,
+                 num_filters=64,
+                 max_num_filters=1024,
+                 num_layers=6,
+                 padding_mode='reflect',
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type':
+                warnings.warn(
+                    "Discriminator argument {} is not used".format(key))
+
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type='none',
+                           weight_norm_type=weight_norm_type,
+                           bias=[True, True, True],
+                           nonlinearity='leakyrelu',
+                           order='NACNAC')
+
+        first_kernel_size = 7
+        first_padding = (first_kernel_size - 1) // 2
+        model = [Conv2dBlock(image_channels, num_filters,
+                             first_kernel_size, 1, first_padding,
+                             padding_mode=padding_mode,
+                             weight_norm_type=weight_norm_type)]
+        for i in range(num_layers):
+            num_filters_prev = num_filters
+            num_filters = min(num_filters * 2, max_num_filters)
+            model += [Res2dBlock(num_filters_prev, num_filters_prev,
+                                 **conv_params),
+                      Res2dBlock(num_filters_prev, num_filters,
+                                 **conv_params)]
+            if i != num_layers - 1:
+                model += [nn.ReflectionPad2d(1),
+                          nn.AvgPool2d(3, stride=2)]
+        self.model = nn.Sequential(*model)
+        self.classifier = Conv2dBlock(num_filters, 1, 1, 1, 0,
+                                      nonlinearity='leakyrelu',
+                                      weight_norm_type=weight_norm_type,
+                                      order='NACNAC')
+
+        self.embedder = nn.Embedding(num_classes, num_filters)
+
+    def forward(self, images, labels=None):
+        r"""Forward function of the projection discriminator.
+
+        Args:
+            images (image tensor): Images fed to the discriminator.
+            labels (long int tensor): Class labels of the images.
+        """
+        assert labels is None or images.size(0) == labels.size(0)
+        features = self.model(images)
+        outputs = self.classifier(features)
+        features_1x1 = features.mean(3).mean(2)
+        if labels is None:
+            return features_1x1
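+        # Projection-style conditioning: the inner product between the class
+        # embedding and the spatially pooled features is added to every patch logit.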
+        embeddings = self.embedder(labels)
+        outputs += torch.sum(embeddings * features_1x1, dim=1,
+                             keepdim=True).view(images.size(0), 1, 1, 1)
+        return outputs, features_1x1
diff --git a/imaginaire/discriminators/gancraft.py b/imaginaire/discriminators/gancraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bc070cb46ac5c6ae287231ddd0144bedd6d55a2
--- /dev/null
+++ b/imaginaire/discriminators/gancraft.py
@@ -0,0 +1,278 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import functools
+from imaginaire.layers import Conv2dBlock
+
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+    r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super(Discriminator, self).__init__()
+        # We assume the first datum is the ground truth image.
+        image_channels = get_paired_input_image_channel_number(data_cfg)
+        # Calculate number of channels in the input label.
+        num_labels = get_paired_input_label_channel_number(data_cfg)
+
+        self.use_label = getattr(dis_cfg, 'use_label', True)
+        # Allow the config to override the channel counts inferred from data_cfg.
+        if hasattr(dis_cfg, 'image_channels'):
+            image_channels = dis_cfg.image_channels
+        if hasattr(dis_cfg, 'num_labels'):
+            num_labels = dis_cfg.num_labels
+
+        if not self.use_label:
+            num_labels = 2  # ignore + true
+
+        # Build the discriminator.
+        num_filters = getattr(dis_cfg, 'num_filters', 128)
+        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+
+        fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
+        fpse_activation_norm_type = getattr(dis_cfg,
+                                            'fpse_activation_norm_type',
+                                            'none')
+        do_multiscale = getattr(dis_cfg, 'do_multiscale', False)
+        smooth_resample = getattr(dis_cfg, 'smooth_resample', False)
+        no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False)
+
+        self.fpse_discriminator = FPSEDiscriminator(
+            image_channels,
+            num_labels,
+            num_filters,
+            fpse_kernel_size,
+            weight_norm_type,
+            fpse_activation_norm_type,
+            do_multiscale,
+            smooth_resample,
+            no_label_except_largest_scale)
+
+    def _single_forward(self, input_label, input_image, weights):
+        output_list, features_list = self.fpse_discriminator(input_image, input_label, weights)
+        return output_list, [features_list]
+
+    def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False):
+        r"""GANcraft discriminator forward.
+
+        Args:
+            data (dict):
+              - data  (N x C1 x H x W tensor) : Ground truth images.
+              - label (N x C2 x H x W tensor) : Semantic representations.
+              - z (N x style_dims tensor): Gaussian random noise.
+            net_G_output (dict):
+              - fake_images  (N x C1 x H x W tensor) : Fake images.
+        Returns:
+            output_x (dict):
+              - real_outputs (list): list of output tensors produced by
+                individual patch discriminators for real images.
+              - real_features (list): list of lists of features produced by
+                individual patch discriminators for real images.
+              - fake_outputs (list): list of output tensors produced by
+                individual patch discriminators for fake images.
+              - fake_features (list): list of lists of features produced by
+                individual patch discriminators for fake images.
+        """
+        output_x = dict()
+
+        # Fake.
+        fake_images = net_G_output['fake_images']
+        if self.use_label:
+            fake_labels = data['fake_masks']
+        else:
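+            # Without semantic labels, build a dummy two-channel map (ignore + true)
+            # with the 'true' channel set at every pixel.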
+            fake_labels = torch.zeros(
+                [fake_images.size(0), 2, fake_images.size(2), fake_images.size(3)],
+                device=fake_images.device, dtype=fake_images.dtype)
+        output_x['fake_outputs'], output_x['fake_features'] = \
+            self._single_forward(fake_labels, fake_images, None)
+
+        # Real.
+        if incl_real:
+            real_images = data['images']
+            if self.use_label:
+                real_labels = data['real_masks']
+            else:
+                real_labels = torch.zeros(
+                    [real_images.size(0), 2, real_images.size(2), real_images.size(3)],
+                    device=real_images.device, dtype=real_images.dtype)
+            output_x['real_outputs'], output_x['real_features'] = \
+                self._single_forward(real_labels, real_images, None)
+
+        # pseudo-Real.
+        if incl_pseudo_real:
+            preal_images = data['pseudo_real_img']
+            preal_labels = data['fake_masks']
+            if not self.use_label:
+                preal_labels = torch.zeros(
+                    [preal_images.size(0), 2, preal_images.size(2), preal_images.size(3)],
+                    device=preal_images.device, dtype=preal_images.dtype)
+            output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \
+                self._single_forward(preal_labels, preal_images, None)
+
+        return output_x
+
+
+class FPSEDiscriminator(nn.Module):
+    def __init__(self,
+                 num_input_channels,
+                 num_labels,
+                 num_filters,
+                 kernel_size,
+                 weight_norm_type,
+                 activation_norm_type,
+                 do_multiscale,
+                 smooth_resample,
+                 no_label_except_largest_scale):
+        super().__init__()
+
+        self.do_multiscale = do_multiscale
+        self.no_label_except_largest_scale = no_label_except_largest_scale
+
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        nonlinearity = 'leakyrelu'
+        stride1_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=1,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        down_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=2,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        latent_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=1,
+                              stride=1,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        # bottom-up pathway
+        self.enc1 = down_conv2d_block(num_input_channels, num_filters)  # 3
+        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)  # 7
+        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)  # 15
+        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)  # 31
+        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)  # 63
+
+        # top-down pathway
+        # self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun
+        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
+        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
+        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+
+        # upsampling
+        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
+                                      align_corners=False)
+
+        # final layers
+        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
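+        # num_labels + 1 output channels: one score per semantic class plus one
+        # extra class, matching the N+1-label setup noted in the Discriminator docstring.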
+        self.output = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+
+        if self.do_multiscale:
+            self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+            self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+            if self.no_label_except_largest_scale:
+                self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
+                self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
+            else:
+                self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+                self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+
+        self.interpolator = functools.partial(F.interpolate, mode='nearest')
+        if smooth_resample:
+            self.interpolator = self.smooth_interp
+
+    @staticmethod
+    def smooth_interp(x, size):
+        r"""Smooth interpolation of segmentation maps.
+
+        Args:
+            x (4D tensor): Segmentation maps.
+            size (2D list): Target size (H, W).
+        """
+        x = F.interpolate(x, size=size, mode='area')
+        onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
+        x.fill_(0.0)
+        x.scatter_(1, onehot_idx, 1.0)
+        return x
+
+    # Weights: [N C]
+    def forward(self, images, segmaps, weights=None):
+        # Assume images 256x256
+        # bottom-up pathway
+        feat11 = self.enc1(images)  # 128
+        feat12 = self.enc2(feat11)  # 64
+        feat13 = self.enc3(feat12)  # 32
+        feat14 = self.enc4(feat13)  # 16
+        feat15 = self.enc5(feat14)  # 8
+        # top-down pathway and lateral connections
+        feat25 = self.lat5(feat15)  # 8
+        feat24 = self.upsample2x(feat25) + self.lat4(feat14)  # 16
+        feat23 = self.upsample2x(feat24) + self.lat3(feat13)  # 32
+        feat22 = self.upsample2x(feat23) + self.lat2(feat12)  # 64
+
+        # final prediction layers
+        feat32 = self.final2(feat22)
+
+        results = []
+        label_map = self.interpolator(segmaps, size=feat32.size()[2:])
+        pred2 = self.output(feat32)  # N, num_labels+1, H//4, W//4
+
+        features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22]
+        if weights is not None:
+            label_map = label_map * weights[..., None, None]
+        results.append({'pred': pred2, 'label': label_map})
+
+        if self.do_multiscale:
+            feat33 = self.final3(feat23)
+            pred3 = self.output3(feat33)
+
+            feat34 = self.final4(feat24)
+            pred4 = self.output4(feat34)
+
+            if self.no_label_except_largest_scale:
+                label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device)
+                label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device)
+            else:
+                label_map3 = self.interpolator(segmaps, size=pred3.size()[2:])
+                label_map4 = self.interpolator(segmaps, size=pred4.size()[2:])
+
+            if weights is not None:
+                label_map3 = label_map3 * weights[..., None, None]
+                label_map4 = label_map4 * weights[..., None, None]
+
+            results.append({'pred': pred3, 'label': label_map3})
+            results.append({'pred': pred4, 'label': label_map4})
+
+        return results, features
diff --git a/imaginaire/discriminators/mlp_multiclass.py b/imaginaire/discriminators/mlp_multiclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7d1d27e2d21a7b4fd5f23646b545a44a47a783
--- /dev/null
+++ b/imaginaire/discriminators/mlp_multiclass.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import numpy as np
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Discriminator(nn.Module):
+    r"""Multi-layer Perceptron Classifier constructor.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super(Discriminator, self).__init__()
+        num_input_channels = dis_cfg.input_dims
+        num_labels = dis_cfg.num_labels
+        num_layers = getattr(dis_cfg, 'num_layers', 5)
+        num_filters = getattr(dis_cfg, 'num_filters', 512)
+        activation_norm_type = getattr(dis_cfg,
+                                       'activation_norm_type',
+                                       'batch_norm')
+        nonlinearity = getattr(dis_cfg, 'nonlinearity', 'leakyrelu')
+        base_linear_block = \
+            functools.partial(LinearBlock,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              order='CNA')
+        dropout_ratio = 0.1
+        layers = [base_linear_block(num_input_channels, num_filters),
+                  nn.Dropout(dropout_ratio)]
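+        # Deeper layers use progressively stronger dropout, capped at 0.5.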
+        for n in range(num_layers):
+            dropout_ratio *= 1.5
+            dropout_ratio = np.min([dropout_ratio, 0.5])
+            layers += [base_linear_block(num_filters, num_filters),
+                       nn.Dropout(dropout_ratio)]
+        layers += [LinearBlock(num_filters, num_labels)]
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, data):
+        r"""Patch Discriminator forward.
+
+        Args:
+            data (dict):
+              - data (tensor): Input tensor; it is flattened to shape (N, -1).
+        Returns:
+            (dict):
+              - results (N x C tensor): Output scores before softmax.
+        """
+        input_x = data['data']
+        bs = input_x.size()[0]
+        input_x = input_x.view(bs, -1)
+        pre_softmax_scores = self.model(input_x)
+        outputs = dict()
+        outputs['results'] = pre_softmax_scores
+        return outputs
diff --git a/imaginaire/discriminators/multires_patch.py b/imaginaire/discriminators/multires_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a31d25c689a3d5b26a3b92216715d3ebd62dc91
--- /dev/null
+++ b/imaginaire/discriminators/multires_patch.py
@@ -0,0 +1,313 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# Copyright (C) 2020 NVIDIA Corporation.  All rights reserved
+import functools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config
+            file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super(Discriminator, self).__init__()
+        print('Multi-resolution patch discriminator initialization.')
+        # We assume the first datum is the ground truth image.
+        image_channels = get_paired_input_image_channel_number(data_cfg)
+        # Calculate number of channels in the input label.
+        num_labels = get_paired_input_label_channel_number(data_cfg)
+
+        # Build the discriminator.
+        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+        num_filters = getattr(dis_cfg, 'num_filters', 128)
+        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+        num_layers = getattr(dis_cfg, 'num_layers', 5)
+        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+        print('\tBase filter number: %d' % num_filters)
+        print('\tNumber of discriminators: %d' % num_discriminators)
+        print('\tNumber of layers in a discriminator: %d' % num_layers)
+        print('\tWeight norm type: %s' % weight_norm_type)
+        num_input_channels = image_channels + num_labels
+        self.model = MultiResPatchDiscriminator(num_discriminators,
+                                                kernel_size,
+                                                num_input_channels,
+                                                num_filters,
+                                                num_layers,
+                                                max_num_filters,
+                                                activation_norm_type,
+                                                weight_norm_type)
+        print('Done with the Multi-resolution patch '
+              'discriminator initialization.')
+
+    def forward(self, data, net_G_output, real=True):
+        r"""SPADE Generator forward.
+
+        Args:
+            data (dict):
+              - data  (N x C1 x H x W tensor) : Ground truth images.
+              - label (N x C2 x H x W tensor) : Semantic representations.
+              - z (N x style_dims tensor): Gaussian random noise.
+            net_G_output (dict):
+                fake_images  (N x C1 x H x W tensor) : Fake images.
+            real (bool): If ``True``, also classifies real images. Otherwise it
+                only classifies generated images to save computation during the
+                generator update.
+        Returns:
+            (dict):
+              - real_outputs (list): list of output tensors produced by
+                individual patch discriminators for real images.
+              - real_features (list): list of lists of features produced by
+                individual patch discriminators for real images.
+              - fake_outputs (list): list of output tensors produced by
+                individual patch discriminators for fake images.
+              - fake_features (list): list of lists of features produced by
+                individual patch discriminators for fake images.
+        """
+        output_x = dict()
+        if 'label' in data:
+            fake_input_x = torch.cat(
+                (data['label'], net_G_output['fake_images']), 1)
+        else:
+            fake_input_x = net_G_output['fake_images']
+        output_x['fake_outputs'], output_x['fake_features'], _ = \
+            self.model.forward(fake_input_x)
+        if real:
+            if 'label' in data:
+                real_input_x = torch.cat(
+                    (data['label'], data['images']), 1)
+            else:
+                real_input_x = data['images']
+            output_x['real_outputs'], output_x['real_features'], _ = \
+                self.model.forward(real_input_x)
+        return output_x
+
+
+class MultiResPatchDiscriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        num_discriminators (int): Num. of discriminators (one per scale).
+        kernel_size (int): Convolution kernel size.
+        num_image_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        num_layers (int): Num. of layers for the patch discriminator.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        activation_norm_type (str): batch_norm/instance_norm/none/....
+        weight_norm_type (str): none/spectral_norm/weight_norm
+    """
+
+    def __init__(self,
+                 num_discriminators=3,
+                 kernel_size=3,
+                 num_image_channels=3,
+                 num_filters=64,
+                 num_layers=4,
+                 max_num_filters=512,
+                 activation_norm_type='',
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type' and key != 'patch_wise':
+                warnings.warn(
+                    "Discriminator argument {} is not used".format(key))
+
+        self.discriminators = nn.ModuleList()
+        for i in range(num_discriminators):
+            net_discriminator = NLayerPatchDiscriminator(
+                kernel_size,
+                num_image_channels,
+                num_filters,
+                num_layers,
+                max_num_filters,
+                activation_norm_type,
+                weight_norm_type)
+            self.discriminators.append(net_discriminator)
+        print('Done with the Multi-resolution patch '
+              'discriminator initialization.')
+
+    def forward(self, input_x):
+        r"""Multi-resolution patch discriminator forward.
+
+        Args:
+            input_x (tensor) : Input images.
+        Returns:
+            (tuple):
+              - output_list (list): list of output tensors produced by
+                individual patch discriminators.
+              - features_list (list): list of lists of features produced by
+                individual patch discriminators.
+              - input_list (list): list of downsampled input images.
+        """
+        input_list = []
+        output_list = []
+        features_list = []
+        input_downsampled = input_x
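+        # Each discriminator sees the input at half the resolution of the previous
+        # one, forming a multi-scale image pyramid.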
+        for net_discriminator in self.discriminators:
+            input_list.append(input_downsampled)
+            output, features = net_discriminator(input_downsampled)
+            output_list.append(output)
+            features_list.append(features)
+            input_downsampled = nn.functional.interpolate(
+                input_downsampled, scale_factor=0.5, mode='bilinear',
+                align_corners=True, recompute_scale_factor=True)
+        return output_list, features_list, input_list
+
+
+class WeightSharedMultiResPatchDiscriminator(nn.Module):
+    r"""Multi-resolution patch discriminator with shared weights.
+
+    Args:
+        num_discriminators (int): Num. of discriminators (one per scale).
+        kernel_size (int): Convolution kernel size.
+        num_image_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        num_layers (int): Num. of layers for the patch discriminator.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        activation_norm_type (str): batch_norm/instance_norm/none/....
+        weight_norm_type (str): none/spectral_norm/weight_norm
+    """
+
+    def __init__(self,
+                 num_discriminators=3,
+                 kernel_size=3,
+                 num_image_channels=3,
+                 num_filters=64,
+                 num_layers=4,
+                 max_num_filters=512,
+                 activation_norm_type='',
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type' and key != 'patch_wise':
+                warnings.warn(
+                    "Discriminator argument {} is not used".format(key))
+        self.num_discriminators = num_discriminators
+        self.discriminator = NLayerPatchDiscriminator(
+            kernel_size,
+            num_image_channels,
+            num_filters,
+            num_layers,
+            max_num_filters,
+            activation_norm_type,
+            weight_norm_type)
+        print('Done with the Weight-Shared Multi-resolution patch '
+              'discriminator initialization.')
+
+    def forward(self, input_x):
+        r"""Multi-resolution patch discriminator forward.
+
+        Args:
+            input_x (tensor) : Input images.
+        Returns:
+            (tuple):
+              - output_list (list): list of output tensors produced by
+                individual patch discriminators.
+              - features_list (list): list of lists of features produced by
+                individual patch discriminators.
+              - input_list (list): list of downsampled input images.
+        """
+        input_list = []
+        output_list = []
+        features_list = []
+        input_downsampled = input_x
+        for i in range(self.num_discriminators):
+            input_list.append(input_downsampled)
+            output, features = self.discriminator(input_downsampled)
+            output_list.append(output)
+            features_list.append(features)
+            input_downsampled = nn.functional.interpolate(
+                input_downsampled, scale_factor=0.5, mode='bilinear',
+                align_corners=True)
+        return output_list, features_list, input_list
+
+
+class NLayerPatchDiscriminator(nn.Module):
+    r"""Patch Discriminator constructor.
+
+    Args:
+        kernel_size (int): Convolution kernel size.
+        num_input_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        num_layers (int): Num. of layers for the patch discriminator.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        activation_norm_type (str): batch_norm/instance_norm/none/....
+        weight_norm_type (str): none/spectral_norm/weight_norm
+    """
+
+    def __init__(self,
+                 kernel_size,
+                 num_input_channels,
+                 num_filters,
+                 num_layers,
+                 max_num_filters,
+                 activation_norm_type,
+                 weight_norm_type):
+        super(NLayerPatchDiscriminator, self).__init__()
+        self.num_layers = num_layers
+        padding = int(np.floor((kernel_size - 1.0) / 2))
+        nonlinearity = 'leakyrelu'
+        base_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        layers = [[base_conv2d_block(
+            num_input_channels, num_filters, stride=2)]]
+        for n in range(num_layers):
+            num_filters_prev = num_filters
+            num_filters = min(num_filters * 2, max_num_filters)
+            stride = 2 if n < (num_layers - 1) else 1
+            layers += [[base_conv2d_block(num_filters_prev, num_filters,
+                                          stride=stride)]]
+        layers += [[Conv2dBlock(num_filters, 1,
+                                3, 1,
+                                padding,
+                                weight_norm_type=weight_norm_type)]]
+        for n in range(len(layers)):
+            setattr(self, 'layer' + str(n), nn.Sequential(*layers[n]))
+
+    def forward(self, input_x):
+        r"""Patch Discriminator forward.
+
+        Args:
+            input_x (N x C x H1 x W1 tensor): Concatenation of images and
+                semantic representations.
+        Returns:
+            (tuple):
+              - output (N x 1 x H2 x W2 tensor): Discriminator output value.
+                Before the sigmoid when using NSGAN.
+              - features (list): lists of tensors of the intermediate
+                activations.
+        """
+        res = [input_x]
+        for n in range(self.num_layers + 2):
+            layer = getattr(self, 'layer' + str(n))
+            x = res[-1]
+            res.append(layer(x))
+        output = res[-1]
+        features = res[1:-1]
+        return output, features
diff --git a/imaginaire/discriminators/multires_patch_pano.py b/imaginaire/discriminators/multires_patch_pano.py
new file mode 100644
index 0000000000000000000000000000000000000000..97763da3bbfc990be1755bbb62383869fd6708da
--- /dev/null
+++ b/imaginaire/discriminators/multires_patch_pano.py
@@ -0,0 +1,247 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# Copyright (C) 2020 NVIDIA Corporation.  All rights reserved
+import functools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+from model.sample import Equirectangular
+
+class Discriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config
+            file.
+    """
+
+    def __init__(self, dis_cfg):
+        super(Discriminator, self).__init__()
+        print('Multi-resolution patch discriminator initialization.')
+        # We assume the first datum is the ground truth image.
+        num_input_channels = getattr(dis_cfg, 'input_channels', 3)
+        # Calculate number of channels in the input label.
+
+        # Build the discriminator.
+        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+        num_filters = getattr(dis_cfg, 'num_filters', 128)
+        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+        num_layers = getattr(dis_cfg, 'num_layers', 5)
+        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+        print('\tBase filter number: %d' % num_filters)
+        print('\tNumber of discriminators: %d' % num_discriminators)
+        print('\tNumber of layers in a discriminator: %d' % num_layers)
+        print('\tWeight norm type: %s' % weight_norm_type)
+        self.condition = getattr(dis_cfg, 'condition', None)
+        # self.condition = dis_cfg.condition
+        self.model = MultiResPatchDiscriminator(num_discriminators,
+                                                kernel_size,
+                                                num_input_channels,
+                                                num_filters,
+                                                num_layers,
+                                                max_num_filters,
+                                                activation_norm_type,
+                                                weight_norm_type)
+        print('Done with the Multi-resolution patch '
+              'discriminator initialization.')
+
+    def forward(self, data, net_G_output, real=True):
+        r"""SPADE Generator forward.
+
+        Args:
+            data  (N x C1 x H x W tensor) : Ground truth images.
+            net_G_output (dict):
+                fake_images  (N x C1 x H x W tensor) : Fake images.
+            real (bool): If ``True``, also classifies real images. Otherwise it
+                only classifies generated images to save computation during the
+                generator update.
+        Returns:
+            (dict):
+              - real_outputs (list): list of output tensors produced by
+                individual patch discriminators for real images.
+              - real_features (list): list of lists of features produced by
+                individual patch discriminators for real images.
+              - fake_outputs (list): list of output tensors produced by
+                individual patch discriminators for fake images.
+              - fake_features (list): list of lists of features produced by
+                individual patch discriminators for fake images.
+        """
+        output_x = dict()
+        if self.condition:
+            fake_input_x = torch.cat(
+                [net_G_output['pred'], net_G_output['generator_inputs']], dim=1)
+        else:
+            fake_input_x = net_G_output['pred']
+        output_x['fake_outputs'], output_x['fake_features'], _ = \
+            self.model.forward(fake_input_x)
+        if real:
+            if self.condition:
+                real_input_x = torch.cat(
+                    [data, net_G_output['generator_inputs']], dim=1)
+            else:
+                real_input_x = data
+            output_x['real_outputs'], output_x['real_features'], _ = \
+                self.model.forward(real_input_x)
+        return output_x
+
+
+class MultiResPatchDiscriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        num_discriminators (int): Num. of discriminators (one per scale).
+        kernel_size (int): Convolution kernel size.
+        num_image_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        num_layers (int): Num. of layers for the patch discriminator.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        activation_norm_type (str): batch_norm/instance_norm/none/....
+        weight_norm_type (str): none/spectral_norm/weight_norm
+    """
+
+    def __init__(self,
+                 num_discriminators=3,
+                 kernel_size=3,
+                 num_image_channels=3,
+                 num_filters=64,
+                 num_layers=4,
+                 max_num_filters=512,
+                 activation_norm_type='',
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type' and key != 'patch_wise':
+                warnings.warn(
+                    "Discriminator argument {} is not used".format(key))
+
+        self.discriminators = nn.ModuleList()
+        for i in range(num_discriminators):
+            net_discriminator = NLayerPatchDiscriminator(
+                kernel_size,
+                num_image_channels,
+                num_filters,
+                num_layers,
+                max_num_filters,
+                activation_norm_type,
+                weight_norm_type)
+            self.discriminators.append(net_discriminator)
+        print('Done with the Multi-resolution patch '
+              'discriminator initialization.')
+        self.e = Equirectangular(theta=[-40., 40.], width=128, height=128, FovX=100)
+
+    def forward(self, input_x):
+        r"""Multi-resolution patch discriminator forward.
+
+        Args:
+            input_x (tensor) : Input images.
+        Returns:
+            (tuple):
+              - output_list (list): list of output tensors produced by
+                individual patch discriminators.
+              - features_list (list): list of lists of features produced by
+                individual patch discriminators.
+              - input_list (list): list of downsampled input images.
+        """
+        input_list = []
+        output_list = []
+        features_list = []
+        input_N = nn.functional.interpolate(
+            input_x, scale_factor=0.5, mode='bilinear',
+            align_corners=True, recompute_scale_factor=True)
+        equ = self.e(input_x)
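+        # The first discriminator sees the half-resolution panorama; after it runs,
+        # the input is swapped for perspective views re-sampled from the panorama via
+        # the equirectangular grid (i == 0), which are then halved again (i == 1).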
+        for i, net_discriminator in enumerate(self.discriminators):
+            input_list.append(input_N)
+            output, features = net_discriminator(input_N)
+            output_list.append(output)
+            features_list.append(features)
+            if i == 0:
+                input_N = torch.nn.functional.grid_sample(
+                    input_x, equ.float(), align_corners=True) * 0.99
+            elif i == 1:
+                input_N = nn.functional.interpolate(
+                    input_N, scale_factor=0.5, mode='bilinear',
+                    align_corners=True, recompute_scale_factor=True)
+
+        return output_list, features_list, input_list
+
+class NLayerPatchDiscriminator(nn.Module):
+    r"""Patch Discriminator constructor.
+
+    Args:
+        kernel_size (int): Convolution kernel size.
+        num_input_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        num_layers (int): Num. of layers for the patch discriminator.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        activation_norm_type (str): batch_norm/instance_norm/none/....
+        weight_norm_type (str): none/spectral_norm/weight_norm
+    """
+
+    def __init__(self,
+                 kernel_size,
+                 num_input_channels,
+                 num_filters,
+                 num_layers,
+                 max_num_filters,
+                 activation_norm_type,
+                 weight_norm_type):
+        super(NLayerPatchDiscriminator, self).__init__()
+        self.num_layers = num_layers
+        padding = int(np.floor((kernel_size - 1.0) / 2))
+        nonlinearity = 'leakyrelu'
+        base_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity,
+                              # inplace_nonlinearity=True,
+                              order='CNA')
+        layers = [[base_conv2d_block(
+            num_input_channels, num_filters, stride=2)]]
+        for n in range(num_layers):
+            num_filters_prev = num_filters
+            num_filters = min(num_filters * 2, max_num_filters)
+            stride = 2 if n < (num_layers - 1) else 1
+            layers += [[base_conv2d_block(num_filters_prev, num_filters,
+                                          stride=stride)]]
+        layers += [[Conv2dBlock(num_filters, 1,
+                                3, 1,
+                                padding,
+                                weight_norm_type=weight_norm_type)]]
+        for n in range(len(layers)):
+            setattr(self, 'layer' + str(n), nn.Sequential(*layers[n]))
+
+
+    def forward(self, input_x):
+        r"""Patch Discriminator forward.
+
+        Args:
+            input_x (N x C x H1 x W1 tensor): Concatenation of images and
+                semantic representations.
+        Returns:
+            (tuple):
+              - output (N x 1 x H2 x W2 tensor): Discriminator output value.
+                Before the sigmoid when using NSGAN.
+              - features (list): lists of tensors of the intermediate
+                activations.
+        """
+        res = [input_x]
+        for n in range(self.num_layers + 2):
+            layer = getattr(self, 'layer' + str(n))
+            x = res[-1]
+            res.append(layer(x))
+        output = res[-1]
+        features = res[1:-1]
+        return output, features
diff --git a/imaginaire/discriminators/munit.py b/imaginaire/discriminators/munit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e407569764dced0962e66c56a1a0b2e9106c683
--- /dev/null
+++ b/imaginaire/discriminators/munit.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+
+from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator
+from imaginaire.discriminators.residual import ResDiscriminator
+
+
+class Discriminator(nn.Module):
+    r"""MUNIT discriminator. It can be either a multi-resolution patch
+    discriminator like in the original implementation, or a
+    global residual discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super().__init__()
+        if getattr(dis_cfg, 'patch_wise', True):
+            # Use the multi-resolution patch discriminator. It works better for
+            # scene images and when you want to preserve pixel-wise
+            # correspondence during translation.
+            self.discriminator_a = \
+                MultiResPatchDiscriminator(**vars(dis_cfg))
+            self.discriminator_b = \
+                MultiResPatchDiscriminator(**vars(dis_cfg))
+        else:
+            # Use the global residual discriminator. It works better if images
+            # have a single centered object (e.g., animal faces, shoes).
+            self.discriminator_a = ResDiscriminator(**vars(dis_cfg))
+            self.discriminator_b = ResDiscriminator(**vars(dis_cfg))
+
+    def forward(self, data, net_G_output, gan_recon=False, real=True):
+        r"""Returns the output of the discriminator.
+
+        Args:
+            data (dict):
+              - images_a  (tensor) : Images in domain A.
+              - images_b  (tensor) : Images in domain B.
+            net_G_output (dict):
+              - images_ab  (tensor) : Images translated from domain A to B by
+                the generator.
+              - images_ba  (tensor) : Images translated from domain B to A by
+                the generator.
+              - images_aa  (tensor) : Reconstructed images in domain A.
+              - images_bb  (tensor) : Reconstructed images in domain B.
+            gan_recon (bool): If ``True``, also classifies reconstructed images.
+            real (bool): If ``True``, also classifies real images. Otherwise it
+                only classifies generated images to save computation during the
+                generator update.
+
+        Returns:
+            (dict):
+              - out_ab (tensor): Output of the discriminator for images
+                translated from domain A to B by the generator.
+              - out_ba (tensor): Output of the discriminator for images
+                translated from domain B to A by the generator.
+              - fea_ab (tensor): Intermediate features of the discriminator
+                for images translated from domain A to B by the generator.
+              - fea_ba (tensor): Intermediate features of the discriminator
+                for images translated from domain B to A by the generator.
+
+              - out_a (tensor): Output of the discriminator for images
+                in domain A.
+              - out_b (tensor): Output of the discriminator for images
+                in domain B.
+              - fea_a (tensor): Intermediate features of the discriminator
+                for images in domain A.
+              - fea_b (tensor): Intermediate features of the discriminator
+                for images in domain B.
+
+              - out_aa (tensor): Output of the discriminator for
+                reconstructed images in domain A.
+              - out_bb (tensor): Output of the discriminator for
+                reconstructed images in domain B.
+              - fea_aa (tensor): Intermediate features of the discriminator
+                for reconstructed images in domain A.
+              - fea_bb (tensor): Intermediate features of the discriminator
+                for reconstructed images in domain B.
+        """
+        out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab'])
+        out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba'])
+        output = dict(out_ba=out_ba, out_ab=out_ab,
+                      fea_ba=fea_ba, fea_ab=fea_ab)
+        if real:
+            out_a, fea_a, _ = self.discriminator_a(data['images_a'])
+            out_b, fea_b, _ = self.discriminator_b(data['images_b'])
+            output.update(dict(out_a=out_a, out_b=out_b,
+                               fea_a=fea_a, fea_b=fea_b))
+        if gan_recon:
+            out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa'])
+            out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb'])
+            output.update(dict(out_aa=out_aa, out_bb=out_bb,
+                               fea_aa=fea_aa, fea_bb=fea_bb))
+        return output
diff --git a/imaginaire/discriminators/residual.py b/imaginaire/discriminators/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65b41df96f8ea25e82d18adf33cabef746f863c
--- /dev/null
+++ b/imaginaire/discriminators/residual.py
@@ -0,0 +1,96 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+from imaginaire.third_party.upfirdn2d import BlurDownsample
+
+
+class ResDiscriminator(nn.Module):
+    r"""Global residual discriminator.
+
+    Args:
+        image_channels (int): Num. of channels in the real/fake image.
+        num_filters (int): Num. of base filters in a layer.
+        max_num_filters (int): Maximum num. of filters in a layer.
+        first_kernel_size (int): Kernel size in the first layer.
+        num_layers (int): Num. of layers in discriminator.
+        padding_mode (str): Padding mode.
+        activation_norm_type (str): Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        aggregation (str): Method to aggregate features across different
+            locations in the final layer. ``'conv'``, or ``'pool'``.
+        order (str): Order of operations in the residual link.
+        anti_aliased (bool): If ``True``, uses anti-aliased pooling.
+    """
+
+    def __init__(self,
+                 image_channels=3,
+                 num_filters=64,
+                 max_num_filters=512,
+                 first_kernel_size=1,
+                 num_layers=4,
+                 padding_mode='zeros',
+                 activation_norm_type='',
+                 weight_norm_type='',
+                 aggregation='conv',
+                 order='pre_act',
+                 anti_aliased=False,
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type' and key != 'patch_wise':
+                warnings.warn(
+                    "Discriminator argument {} is not used".format(key))
+
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type=activation_norm_type,
+                           weight_norm_type=weight_norm_type,
+                           nonlinearity='leakyrelu')
+
+        first_padding = (first_kernel_size - 1) // 2
+        model = [Conv2dBlock(image_channels, num_filters,
+                             first_kernel_size, 1, first_padding,
+                             **conv_params)]
+        for _ in range(num_layers):
+            num_filters_prev = num_filters
+            num_filters = min(num_filters * 2, max_num_filters)
+            model.append(Res2dBlock(num_filters_prev, num_filters, order=order,
+                                    **conv_params))
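+            # Downsample by a factor of 2 after every residual block; BlurDownsample
+            # low-pass filters before subsampling to reduce aliasing.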
+            if anti_aliased:
+                model.append(BlurDownsample())
+            else:
+                model.append(nn.AvgPool2d(2, stride=2))
+        if aggregation == 'pool':
+            model += [torch.nn.AdaptiveAvgPool2d(1)]
+        elif aggregation == 'conv':
+            model += [Conv2dBlock(num_filters, num_filters, 4, 1, 0,
+                                  nonlinearity='leakyrelu')]
+        else:
+            raise ValueError(
+                'The aggregation mode %s is not recognized' % aggregation)
+        self.model = nn.Sequential(*model)
+        self.classifier = nn.Linear(num_filters, 1)
+
+    def forward(self, images):
+        r"""Multi-resolution patch discriminator forward.
+
+        Args:
+            images (tensor) : Input images.
+        Returns:
+            (tuple):
+              - outputs (tensor): Output of the discriminator.
+              - features (tensor): Intermediate features of the discriminator.
+              - images (tensor): Input images.
+        """
+        batch_size = images.size(0)
+        features = self.model(images)
+        outputs = self.classifier(features.view(batch_size, -1))
+        return outputs, features, images
diff --git a/imaginaire/discriminators/spade.py b/imaginaire/discriminators/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85d1c5926ce92e948501c7afaf1d1324e1ea38c
--- /dev/null
+++ b/imaginaire/discriminators/spade.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+from imaginaire.discriminators.fpse import FPSEDiscriminator
+from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+    r"""Multi-resolution patch discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super(Discriminator, self).__init__()
+        print('Multi-resolution patch discriminator initialization.')
+        image_channels = getattr(dis_cfg, 'image_channels', None)
+        if image_channels is None:
+            image_channels = get_paired_input_image_channel_number(data_cfg)
+        num_labels = getattr(dis_cfg, 'num_labels', None)
+        if num_labels is None:
+            # Calculate number of channels in the input label when not specified.
+            num_labels = get_paired_input_label_channel_number(data_cfg)
+
+        # Build the discriminator.
+        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+        num_filters = getattr(dis_cfg, 'num_filters', 128)
+        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+        num_layers = getattr(dis_cfg, 'num_layers', 5)
+        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+        print('\tBase filter number: %d' % num_filters)
+        print('\tNumber of discriminators: %d' % num_discriminators)
+        print('\tNumber of layers in a discriminator: %d' % num_layers)
+        print('\tWeight norm type: %s' % weight_norm_type)
+        num_input_channels = image_channels + num_labels
+        self.discriminators = nn.ModuleList()
+        for i in range(num_discriminators):
+            net_discriminator = NLayerPatchDiscriminator(
+                kernel_size,
+                num_input_channels,
+                num_filters,
+                num_layers,
+                max_num_filters,
+                activation_norm_type,
+                weight_norm_type)
+            self.discriminators.append(net_discriminator)
+        print('Done with the Multi-resolution patch discriminator initialization.')
+        self.use_fpse = getattr(dis_cfg, 'use_fpse', True)
+        if self.use_fpse:
+            fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
+            fpse_activation_norm_type = getattr(dis_cfg,
+                                                'fpse_activation_norm_type',
+                                                'none')
+            self.fpse_discriminator = FPSEDiscriminator(
+                image_channels,
+                num_labels,
+                num_filters,
+                fpse_kernel_size,
+                weight_norm_type,
+                fpse_activation_norm_type)
+
+    def _single_forward(self, input_label, input_image):
+        # Compute discriminator outputs and intermediate features from input
+        # images and semantic labels.
+        input_x = torch.cat(
+            (input_label, input_image), 1)
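+        # The multi-resolution patch discriminators take the channel-wise
+        # concatenation of the label map and the image; the FPSE discriminator
+        # below receives them separately.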
+        output_list = []
+        features_list = []
+        if self.use_fpse:
+            pred2, pred3, pred4 = self.fpse_discriminator(input_image, input_label)
+            output_list = [pred2, pred3, pred4]
+        input_downsampled = input_x
+        for net_discriminator in self.discriminators:
+            output, features = net_discriminator(input_downsampled)
+            output_list.append(output)
+            features_list.append(features)
+            input_downsampled = nn.functional.interpolate(
+                input_downsampled, scale_factor=0.5, mode='bilinear',
+                align_corners=True)
+        return output_list, features_list
+
+    def forward(self, data, net_G_output):
+        r"""SPADE discriminator forward.
+
+        Args:
+            data (dict):
+              - data  (N x C1 x H x W tensor) : Ground truth images.
+              - label (N x C2 x H x W tensor) : Semantic representations.
+              - z (N x style_dims tensor): Gaussian random noise.
+            net_G_output (dict):
+                fake_images  (N x C1 x H x W tensor) : Fake images.
+        Returns:
+            (dict):
+              - real_outputs (list): list of output tensors produced by
+                individual patch discriminators for real images.
+              - real_features (list): list of lists of features produced by
+                individual patch discriminators for real images.
+              - fake_outputs (list): list of output tensors produced by
+                individual patch discriminators for fake images.
+              - fake_features (list): list of lists of features produced by
+                individual patch discriminators for fake images.
+        """
+        output_x = dict()
+        output_x['real_outputs'], output_x['real_features'] = \
+            self._single_forward(data['label'], data['images'])
+        output_x['fake_outputs'], output_x['fake_features'] = \
+            self._single_forward(data['label'], net_G_output['fake_images'])
+        return output_x
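+
+
+# Illustrative usage sketch (not part of the original file; the tensor names and the
+# discriminator instance ``net_D`` are assumptions). The forward pass above only needs
+# the dictionary keys documented in its docstring:
+#
+#     data = {'images': real_images, 'label': semantic_maps}
+#     net_G_output = {'fake_images': fake_images}
+#     out = net_D(data, net_G_output)
+#     # out['real_outputs'] / out['fake_outputs'] are lists with one entry per patch
+#     # discriminator (preceded by the FPSE predictions when ``use_fpse`` is enabled).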
diff --git a/imaginaire/discriminators/unit.py b/imaginaire/discriminators/unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb537fe47ea17051a012f7e647c5ac6b6ea1b7d9
--- /dev/null
+++ b/imaginaire/discriminators/unit.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+
+from imaginaire.discriminators.multires_patch import \
+    WeightSharedMultiResPatchDiscriminator
+from imaginaire.discriminators.residual import ResDiscriminator
+
+
+class Discriminator(nn.Module):
+    r"""UNIT discriminator. It can be either a multi-resolution patch
+    discriminator like in the original implementation, or a
+    global residual discriminator.
+
+    Args:
+        dis_cfg (obj): Discriminator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, dis_cfg, data_cfg):
+        super().__init__()
+        if getattr(dis_cfg, 'patch_dis', True):
+            # Use the multi-resolution patch discriminator. It works better for
+            # scene images and when you want to preserve pixel-wise
+            # correspondence during translation.
+            self.discriminator_a = \
+                WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg))
+            self.discriminator_b = \
+                WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg))
+        else:
+            # Use the global residual discriminator. It works better if images
+            # have a single centered object (e.g., animal faces, shoes).
+            self.discriminator_a = ResDiscriminator(**vars(dis_cfg))
+            self.discriminator_b = ResDiscriminator(**vars(dis_cfg))
+
+    def forward(self, data, net_G_output, gan_recon=False, real=True):
+        r"""Returns the output of the discriminator.
+
+        Args:
+            data (dict):
+              - images_a  (tensor) : Images in domain A.
+              - images_b  (tensor) : Images in domain B.
+            net_G_output (dict):
+              - images_ab  (tensor) : Images translated from domain A to B by
+                the generator.
+              - images_ba  (tensor) : Images translated from domain B to A by
+                the generator.
+              - images_aa  (tensor) : Reconstructed images in domain A.
+              - images_bb  (tensor) : Reconstructed images in domain B.
+            gan_recon (bool): If ``True``, also classifies reconstructed images.
+            real (bool): If ``True``, also classifies real images. Otherwise it
+                only classifies generated images to save computation during the
+                generator update.
+        Returns:
+            (dict):
+              - out_ab (tensor): Output of the discriminator for images
+                translated from domain A to B by the generator.
+              - out_ba (tensor): Output of the discriminator for images
+                translated from domain B to A by the generator.
+              - fea_ab (tensor): Intermediate features of the discriminator
+                for images translated from domain A to B by the generator.
+              - fea_ba (tensor): Intermediate features of the discriminator
+                for images translated from domain B to A by the generator.
+
+              - out_a (tensor): Output of the discriminator for images
+                in domain A.
+              - out_b (tensor): Output of the discriminator for images
+                in domain B.
+              - fea_a (tensor): Intermediate features of the discriminator
+                for images in domain A.
+              - fea_b (tensor): Intermediate features of the discriminator
+                for images in domain B.
+
+              - out_aa (tensor): Output of the discriminator for
+                reconstructed images in domain A.
+              - out_bb (tensor): Output of the discriminator for
+                reconstructed images in domain B.
+              - fea_aa (tensor): Intermediate features of the discriminator
+                for reconstructed images in domain A.
+              - fea_bb (tensor): Intermediate features of the discriminator
+                for reconstructed images in domain B.
+        """
+        out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab'])
+        out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba'])
+        output = dict(out_ba=out_ba, out_ab=out_ab,
+                      fea_ba=fea_ba, fea_ab=fea_ab)
+        if real:
+            out_a, fea_a, _ = self.discriminator_a(data['images_a'])
+            out_b, fea_b, _ = self.discriminator_b(data['images_b'])
+            output.update(dict(out_a=out_a, out_b=out_b,
+                               fea_a=fea_a, fea_b=fea_b))
+        if gan_recon:
+            out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa'])
+            out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb'])
+            output.update(dict(out_aa=out_aa, out_bb=out_bb,
+                               fea_aa=fea_aa, fea_bb=fea_bb))
+        return output
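+
+
+# Illustrative usage sketch (not part of the original file; the tensor names and the
+# discriminator instance ``net_D`` are assumptions). During the generator update one
+# would typically skip scoring real images:
+#
+#     data = {'images_a': images_a, 'images_b': images_b}
+#     net_G_output = {'images_ab': images_ab, 'images_ba': images_ba}
+#     d_out = net_D(data, net_G_output, real=False)
+#     # d_out then contains only out_ab / out_ba and fea_ab / fea_ba; pass real=True
+#     # (the default) during the discriminator update to also score real images.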
diff --git a/imaginaire/evaluation/__init__.py b/imaginaire/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a379a7be63550921dcd5802e7621196df9de9e1
--- /dev/null
+++ b/imaginaire/evaluation/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .fid import compute_fid, compute_fid_data
+from .kid import compute_kid, compute_kid_data
+from .prdc import compute_prdc
+from .common import compute_all_metrics, compute_all_metrics_data
+
+__all__ = ['compute_fid', 'compute_fid_data', 'compute_kid', 'compute_kid_data',
+           'compute_prdc', 'compute_all_metrics', 'compute_all_metrics_data']
diff --git a/imaginaire/evaluation/caption/__init__.py b/imaginaire/evaluation/caption/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3197a51d0318c393a6659495d97914b1192f587a
--- /dev/null
+++ b/imaginaire/evaluation/caption/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .r_precision import get_r_precision
+from .common import get_image_encoder
+
+__all__ = ['get_image_encoder', 'get_r_precision']
diff --git a/imaginaire/evaluation/caption/clip.py b/imaginaire/evaluation/caption/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cdcd94e80afcb310564d9d670c3f73f5e065707
--- /dev/null
+++ b/imaginaire/evaluation/caption/clip.py
@@ -0,0 +1,576 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+# https://github.com/openai/CLIP
+import hashlib
+import os
+import urllib
+import warnings
+from time import sleep
+from typing import Union, List
+
+
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, \
+    Normalize
+from tqdm import tqdm
+
+__all__ = ["available_models", "load", 'build_model']
+
+from imaginaire.utils.io import download_file_from_google_drive
+
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(
+            f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if hashlib.sha256(open(download_target,
+                               "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target,
+                                                     "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80,
+                  unit='iB', unit_scale=True) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.sha256(
+            open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            f"Model has been downloaded but the SHA256 checksum does not not match")
+
+    return download_target
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=Image.BICUBIC),
+        CenterCrop(n_px),
+        lambda image: image.convert("RGB"),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+def load(model_path):
+    if not os.path.exists(model_path):
+        downloaded = False
+        while not downloaded:
+            try:
+                download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
+                downloaded = True
+            except Exception as e:
+                print(e)
+                sleep(30)
+                continue
+    model = torch.load(model_path, map_location='cpu')
+    model = build_model(model).cuda()
+    return model, _transform(model.visual.input_resolution)
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1,
+                                bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int,
+                 output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+            2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x, key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat(
+                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+
+        return x[0]
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224,
+                 width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2,
+                               padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1,
+                               bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1,
+                               bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
+                                        heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
+                             (self.conv3, self.bn3)]:
+                x = self.relu(bn(conv(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int,
+                 attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype,
+                                           device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[
+            0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int,
+                 attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(
+            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in
+              range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisualTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int,
+                 layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
+                               kernel_size=patch_size, stride=patch_size,
+                               bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(
+            scale * torch.randn((input_resolution // patch_size) ** 2 + 1,
+                                width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1],
+                      -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(
+            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
+            dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(self,
+                 embed_dim: int,
+                 # vision
+                 image_resolution: int,
+                 vision_layers: Union[Tuple[int, int, int, int], int],
+                 vision_width: int,
+                 vision_patch_size: int,
+                 # text
+                 context_length: int,
+                 vocab_size: int,
+                 transformer_width: int,
+                 transformer_heads: int,
+                 transformer_layers: int
+                 ):
+        super().__init__()
+
+        self.context_length = context_length
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisualTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask()
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(
+            torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [self.visual.layer1, self.visual.layer2,
+                                 self.visual.layer3, self.visual.layer4]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * (
+            (2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection,
+                            std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
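+        # Illustrative note (not part of the original file): with context_length == 4
+        # the resulting additive mask is
+        #     [[0., -inf, -inf, -inf],
+        #      [0.,   0., -inf, -inf],
+        #      [0.,   0.,   0., -inf],
+        #      [0.,   0.,   0.,   0.]]
+        # so each text token attends only to itself and to earlier positions.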
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, image):
+        return self.visual(image.type(self.dtype))
+
+    def encode_text(self, text):
+        x = self.token_embedding(text).type(
+            self.dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(
+            dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # normalized features
+        image_features = image_features / image_features.norm(dim=-1,
+                                                              keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logit_scale * text_features @ image_features.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                         "in_proj_bias", "bias_k", "bias_v"]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len([k for k in state_dict.keys() if
+                             k.startswith("visual.") and k.endswith(
+                                 ".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round(
+            (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [len(set(k.split(".")[2] for k in state_dict if
+                                k.startswith(f"visual.layer{b}"))) for b in
+                        [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict[
+            "visual.attnpool.positional_embedding"].shape[
+            0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == \
+            state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if
+                                 k.startswith(f"transformer.resblocks")))
+
+    model = CLIP(
+        embed_dim,
+        image_resolution, vision_layers, vision_width, vision_patch_size,
+        context_length, vocab_size, transformer_width, transformer_heads,
+        transformer_layers
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    convert_weights(model)
+    model.load_state_dict(state_dict)
+    return model.eval()
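+
+
+# Illustrative usage sketch (not part of the original file; the checkpoint path and
+# the image file name are assumptions):
+#
+#     model, preprocess = load('ViT-B-32.pt')            # downloads if missing
+#     image = preprocess(Image.open('example.jpg')).unsqueeze(0).cuda()
+#     with torch.no_grad():
+#         image_features = model.encode_image(image)     # [1, embed_dim]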
diff --git a/imaginaire/evaluation/caption/common.py b/imaginaire/evaluation/caption/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed54e55335b9b9684eab2fc58cbe14e810597dd6
--- /dev/null
+++ b/imaginaire/evaluation/caption/common.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import boto3
+import torch
+from torch import nn, distributed as dist
+from torch.nn import functional as F
+from torch.distributed import barrier
+
+from imaginaire.utils.distributed import is_local_master
+from .clip import build_model
+from ...utils.io import download_file_from_google_drive
+
+
+def get_image_encoder(aws_credentials=None):
+    if dist.is_initialized() and not is_local_master():
+        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+        barrier()
+
+    # Load the CLIP image encoder.
+    print("Loading CLIP image encoder.")
+    model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt')
+    if not os.path.exists(model_path):
+        if aws_credentials is not None:
+            s3 = boto3.client('s3', **aws_credentials)
+            s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path)
+        else:
+            download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
+    model = torch.load(model_path, map_location='cpu')
+
+    if dist.is_initialized() and is_local_master():
+        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+        barrier()
+
+    encoder = build_model(model).cuda()
+    return ImageEncoder(encoder)
+
+
+class ImageEncoder(nn.Module):
+    def __init__(self, encoder):
+        super().__init__()
+        self.model = encoder
+        self.image_size = self.model.visual.input_resolution
+        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
+        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")
+
+    @torch.no_grad()
+    def forward(self, data, fake_images, align_corners=True):
+        images = 0.5 * (1 + fake_images)
+        images = F.interpolate(images, (self.image_size, self.image_size), mode='bicubic', align_corners=align_corners)
+        images.clamp_(0, 1)
+        images = (images - self.mean[None, :, None, None]) / (self.std[None, :, None, None])
+        image_code = self.model.encode_image(images)
+        return torch.cat((image_code, data['captions-clip']), dim=1)
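+
+
+# Illustrative shape note (not part of the original file): with the ViT-B/32 encoder
+# the CLIP embedding dimension is 512, so for N generated images in [-1, 1]
+#
+#     codes = image_encoder(data, fake_images)   # [N, 1024]: image code, then text code
+#
+# where ``image_encoder`` is an ImageEncoder instance and ``data['captions-clip']``
+# holds the pre-computed CLIP text codes. This is the layout expected by
+# ``get_r_precision`` in ``r_precision.py`` (it splits the tensor in half along dim=1).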
diff --git a/imaginaire/evaluation/caption/r_precision.py b/imaginaire/evaluation/caption/r_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec192e67cf38d24f4f238d151661367f9a5fa753
--- /dev/null
+++ b/imaginaire/evaluation/caption/r_precision.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+
+import torch
+import torch.nn.functional as F
+
+
+def get_r_precision(image_text_code, eps=1e-5):
+    all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1)
+    P_rates = []
+    num_samples = len(all_image_code)
+    assert num_samples >= 100
+    for i in range(0, num_samples, 100):
+        if i + 100 <= num_samples:
+            cur_image_code = all_image_code[i:i + 100]
+            cur_text_code = all_text_code[i:i + 100]
+            cur_image_code = F.normalize(cur_image_code, dim=1, eps=eps)
+            cur_text_code = F.normalize(cur_text_code, dim=1, eps=eps)
+            cosine_similarities = cur_image_code @ cur_text_code.T
+            top1_indices = torch.topk(cosine_similarities, dim=1, k=1)[1][:, 0]
+            P_rate = torch.sum(top1_indices == torch.arange(100, device=top1_indices.device)).item()
+            P_rates.append(P_rate)
+    A_precision = sum(P_rates) * 1.0 / len(P_rates)
+    return {"caption_rprec": A_precision}
diff --git a/imaginaire/evaluation/common.py b/imaginaire/evaluation/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..046c47da6692f5bd24543cc0e72da2a86b7ef2a3
--- /dev/null
+++ b/imaginaire/evaluation/common.py
@@ -0,0 +1,651 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import os
+from functools import partial
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models import inception_v3
+from cleanfid.features import feature_extractor
+from cleanfid.resize import build_resizer
+
+from imaginaire.evaluation.lpips import get_lpips_model
+from imaginaire.evaluation.segmentation import get_segmentation_hist_model, get_miou
+from imaginaire.evaluation.caption import get_image_encoder, get_r_precision
+from imaginaire.evaluation.pretrained import TFInceptionV3, InceptionV3, Vgg16, SwAV
+from imaginaire.utils.distributed import (dist_all_gather_tensor, get_rank,
+                                          get_world_size, is_master,
+                                          is_local_master)
+from imaginaire.utils.distributed import master_only_print
+from imaginaire.utils.misc import apply_imagenet_normalization, to_cuda
+
+
+@torch.no_grad()
+def compute_all_metrics(act_dir,
+                        data_loader,
+                        net_G,
+                        key_real='images',
+                        key_fake='fake_images',
+                        sample_size=None,
+                        preprocess=None,
+                        is_video=False,
+                        few_shot_video=False,
+                        kid_num_subsets=1,
+                        kid_subset_size=None,
+                        key_prefix='',
+                        prdc_k=5,
+                        metrics=None,
+                        dataset_name='',
+                        aws_credentials=None,
+                        **kwargs):
+    r"""
+    Args:
+        act_dir (string): Path to a directory to temporarily save feature activations.
+        data_loader (obj): PyTorch dataloader object.
+        net_G (obj): The generator module.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        sample_size (int or None): How many samples to use for FID.
+        preprocess (func or None): Pre-processing function to use.
+        is_video (bool): Whether we are handling video sequences.
+        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+        kid_num_subsets (int): Number of subsets for KID evaluation.
+        kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
+        key_prefix (string): Add this string before all keys of the output dictionary.
+        prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
+        metrics (list of strings): Which metrics we want to evaluate.
+        dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
+            use for segmentation evaluation.
+    Returns:
+        (dict): All the computed metrics. Only the master process populates it;
+            the other processes receive an empty dictionary.
+    """
+
+    from imaginaire.evaluation.fid import _calculate_frechet_distance
+    from imaginaire.evaluation.kid import _polynomial_mmd_averages
+    from imaginaire.evaluation.prdc import _get_prdc
+    from imaginaire.evaluation.msid import _get_msid
+    from imaginaire.evaluation.knn import _get_1nn_acc
+    if metrics is None:
+        metrics = []
+    act_path = os.path.join(act_dir, 'activations_real.pt')
+
+    # Get feature activations and other outputs computed from fake images.
+    output_module_dict = nn.ModuleDict()
+    if "seg_mIOU" in metrics:
+        output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
+    if "caption_rprec" in metrics:
+        output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
+    if "LPIPS" in metrics:
+        output_module_dict["LPIPS"] = get_lpips_model()
+
+    fake_outputs = get_outputs(
+        data_loader, key_real, key_fake, net_G, sample_size, preprocess,
+        output_module_dict=output_module_dict, **kwargs
+    )
+    fake_act = fake_outputs["activations"]
+
+    # Get feature activations computed from real images.
+    real_act = load_or_compute_activations(
+        act_path, data_loader, key_real, key_fake, None,
+        sample_size, preprocess, is_video=is_video,
+        few_shot_video=few_shot_video, **kwargs
+    )
+
+    metrics_from_activations = {
+        "1NN": _get_1nn_acc,
+        "MSID": _get_msid,
+        "FID": _calculate_frechet_distance,
+        "KID": partial(_polynomial_mmd_averages,
+                       n_subsets=kid_num_subsets,
+                       subset_size=kid_subset_size,
+                       ret_var=True),
+        "PRDC": partial(_get_prdc, nearest_k=prdc_k)
+    }
+
+    other_metrics = {
+        "seg_mIOU": get_miou,
+        "caption_rprec": get_r_precision,
+        "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
+    }
+
+    all_metrics = {}
+    if is_master():
+        for metric in metrics:
+            if metric in metrics_from_activations:
+                metric_function = metrics_from_activations[metric]
+                metric_dict = metric_function(real_act, fake_act)
+            elif metric in other_metrics:
+                metric_function = other_metrics[metric]
+                if fake_outputs[metric] is None:
+                    # Skip metrics whose auxiliary outputs were not computed, so a
+                    # stale or undefined ``metric_dict`` is never reused below.
+                    continue
+                metric_dict = metric_function(fake_outputs[metric])
+            else:
+                print(f"{metric} is not implemented!")
+                raise NotImplementedError
+            for k, v in metric_dict.items():
+                all_metrics.update({key_prefix + k: v})
+    if dist.is_initialized():
+        dist.barrier()
+    return all_metrics
+
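+# Illustrative call sketch (not part of the original file; the loader, generator and
+# directory names are assumptions):
+#
+#     metrics = compute_all_metrics(
+#         act_dir='logs/eval', data_loader=val_loader, net_G=net_G,
+#         sample_size=5000, metrics=['FID', 'KID'], kid_subset_size=1000)
+#
+#     # On the master process ``metrics`` maps metric names to scalar values;
+#     # other ranks receive an empty dict.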
+
+@torch.no_grad()
+def compute_all_metrics_data(data_loader_a,
+                             data_loader_b,
+                             key_a='images',
+                             key_b='images',
+                             sample_size=None,
+                             preprocess=None,
+                             kid_num_subsets=1,
+                             kid_subset_size=None,
+                             key_prefix='',
+                             prdc_k=5,
+                             metrics=None,
+                             dataset_name='',
+                             aws_credentials=None,
+                             **kwargs):
+    r"""
+    Args:
+        data_loader_a (obj): PyTorch dataloader object for the first dataset.
+        data_loader_b (obj): PyTorch dataloader object for the second dataset.
+        key_a (str): Dictionary key value for images from the first dataset.
+        key_b (str): Dictionary key value for images from the second dataset.
+        sample_size (int or None): How many samples to use for the metrics.
+        preprocess (func or None): Pre-processing function to use.
+        kid_num_subsets (int): Number of subsets for KID evaluation.
+        kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
+        key_prefix (string): Add this string before all keys of the output dictionary.
+        prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
+        metrics (list of strings): Which metrics we want to evaluate.
+        dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
+            use for segmentation evaluation.
+    Returns:
+        (dict): All the computed metrics. Only the master process populates it;
+            the other processes receive an empty dictionary.
+    """
+
+    from imaginaire.evaluation.fid import _calculate_frechet_distance
+    from imaginaire.evaluation.kid import _polynomial_mmd_averages
+    from imaginaire.evaluation.prdc import _get_prdc
+    from imaginaire.evaluation.msid import _get_msid
+    from imaginaire.evaluation.knn import _get_1nn_acc
+    if metrics is None:
+        metrics = []
+
+    min_data_size = min(len(data_loader_a.dataset),
+                        len(data_loader_b.dataset))
+    if sample_size is None:
+        sample_size = min_data_size
+    else:
+        sample_size = min(sample_size, min_data_size)
+
+    # Get feature activations and other outputs computed from fake images.
+    output_module_dict = nn.ModuleDict()
+    if "seg_mIOU" in metrics:
+        output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
+    if "caption_rprec" in metrics:
+        output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
+    if "LPIPS" in metrics:
+        output_module_dict["LPIPS"] = get_lpips_model()
+
+    fake_outputs = get_outputs(
+        data_loader_b, key_a, key_b, None, sample_size, preprocess,
+        output_module_dict=output_module_dict, **kwargs
+    )
+    act_b = fake_outputs["activations"]
+
+    act_a = load_or_compute_activations(
+        None, data_loader_a, key_a, key_b, None, sample_size, preprocess,
+        output_module_dict=output_module_dict, **kwargs
+    )
+
+    # act_b = load_or_compute_activations(
+    #     None, data_loader_b, key_a, key_b, None, sample_size, preprocess,
+    #     output_module_dict=output_module_dict, generate_twice=generate_twice, **kwargs
+    # )
+
+    metrics_from_activations = {
+        "1NN": _get_1nn_acc,
+        "MSID": _get_msid,
+        "FID": _calculate_frechet_distance,
+        "KID": partial(_polynomial_mmd_averages,
+                       n_subsets=kid_num_subsets,
+                       subset_size=kid_subset_size,
+                       ret_var=True),
+        "PRDC": partial(_get_prdc, nearest_k=prdc_k)
+    }
+
+    other_metrics = {
+        "seg_mIOU": get_miou,
+        "caption_rprec": get_r_precision,
+        "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
+    }
+
+    all_metrics = {}
+    if is_master():
+        for metric in metrics:
+            if metric in metrics_from_activations:
+                metric_function = metrics_from_activations[metric]
+                metric_dict = metric_function(act_a, act_b)
+            elif metric in other_metrics:
+                metric_function = other_metrics[metric]
+                if fake_outputs[metric] is None:
+                    # Skip metrics whose auxiliary outputs were not computed, so a
+                    # stale or undefined ``metric_dict`` is never reused below.
+                    continue
+                metric_dict = metric_function(fake_outputs[metric])
+            else:
+                print(f"{metric} is not implemented!")
+                raise NotImplementedError
+            for k, v in metric_dict.items():
+                all_metrics.update({key_prefix + k: v})
+    if dist.is_initialized():
+        dist.barrier()
+    return all_metrics
+
+
+@torch.no_grad()
+def get_activations(data_loader, key_real, key_fake,
+                    generator=None, sample_size=None, preprocess=None,
+                    align_corners=True, network='inception', **kwargs):
+    r"""Compute activation values and pack them in a list.
+
+    Args:
+        data_loader (obj): PyTorch dataloader object.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        generator (obj): PyTorch trainer network.
+        sample_size (int): How many samples to use for FID.
+        preprocess (func): Pre-processing function to use.
+        align_corners (bool): The ``'align_corners'`` parameter to be used for
+            `torch.nn.functional.interpolate`.
+    Returns:
+        batch_y (tensor): Inception features of the current batch. Note that
+            only the master gpu will get it.
+    """
+    if dist.is_initialized() and not is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        torch.distributed.barrier()
+
+    if network == 'tf_inception':
+        model = TFInceptionV3()
+    elif network == 'inception':
+        model = InceptionV3()
+    elif network == 'vgg16':
+        model = Vgg16()
+    elif network == 'swav':
+        model = SwAV()
+    elif network == 'clean_inception':
+        model = CleanInceptionV3()
+    else:
+        raise NotImplementedError(f'Network "{network}" is not supported!')
+
+    if dist.is_initialized() and is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        dist.barrier()
+
+    model = model.to('cuda').eval()
+    world_size = get_world_size()
+    batch_y = []
+
+    # Iterate through the dataset to compute the activation.
+    for it, data in enumerate(data_loader):
+        data = to_cuda(data)
+        # Preprocess the data.
+        if preprocess is not None:
+            data = preprocess(data)
+        # Load real data if the generator is not specified.
+        if generator is None:
+            images = data[key_real]
+        else:
+            # Compute the generated image.
+            net_G_output = generator(data, **kwargs)
+            images = net_G_output[key_fake]
+        # Clamp the image for models that do not set the output to between
+        # -1, 1. For models that employ tanh, this has no effect.
+        images.clamp_(-1, 1)
+        y = model(images, align_corners=align_corners)
+        batch_y.append(y)
+        if sample_size is not None and \
+                data_loader.batch_size * world_size * (it + 1) >= sample_size:
+            # Reach the number of samples we need.
+            break
+
+    batch_y = torch.cat(dist_all_gather_tensor(torch.cat(batch_y)))
+    if sample_size is not None:
+        batch_y = batch_y[:sample_size]
+    print(f"Computed feature activations of size {batch_y.shape}")
+    return batch_y
+
+
+class CleanInceptionV3(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = feature_extractor(name="torchscript_inception", resize_inside=False)
+
+    def forward(self, img_batch, transform=True, **_kwargs):
+        if transform:
+            # Assume the input is (-1, 1). We transform it to (0, 255) and round it to the closest integer.
+            img_batch = torch.round(255 * (0.5 * img_batch + 0.5))
+        resized_batch = clean_resize(img_batch)
+        return self.model(resized_batch)
+
+
+def clean_resize(img_batch):
+    # Resize images from arbitrary resolutions to 299x299.
+    batch_size = img_batch.size(0)
+    img_batch = img_batch.cpu().numpy()
+    fn_resize = build_resizer('clean')
+    resized_batch = torch.zeros(batch_size, 3, 299, 299, device='cuda')
+    for idx in range(batch_size):
+        curr_img = img_batch[idx]
+        img_np = curr_img.transpose((1, 2, 0))
+        img_resize = fn_resize(img_np)
+        resized_batch[idx] = torch.tensor(img_resize.transpose((2, 0, 1)), device='cuda')
+    resized_batch = resized_batch.cuda()
+    return resized_batch
+
+
+@torch.no_grad()
+def get_outputs(data_loader, key_real, key_fake,
+                generator=None, sample_size=None, preprocess=None,
+                align_corners=True, network='inception',
+                output_module_dict=None, **kwargs):
+    r"""Compute activation values and pack them in a list.
+
+    Args:
+        data_loader (obj): PyTorch dataloader object.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        generator (obj): PyTorch trainer network.
+        sample_size (int): How many samples to use for FID.
+        preprocess (func): Pre-processing function to use.
+        align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`.
+    Returns:
+        (dict): Feature activations under the key ``'activations'``, plus the
+            concatenated outputs of any additional metric modules.
+    """
+    if output_module_dict is None:
+        output_module_dict = nn.ModuleDict()
+    if dist.is_initialized() and not is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        torch.distributed.barrier()
+
+    if network == 'tf_inception':
+        model = TFInceptionV3()
+    elif network == 'inception':
+        model = InceptionV3()
+    elif network == 'vgg16':
+        model = Vgg16()
+    elif network == 'swav':
+        model = SwAV()
+    elif network == 'clean_inception':
+        model = CleanInceptionV3()
+    else:
+        raise NotImplementedError(f'Network "{network}" is not supported!')
+
+    if dist.is_initialized() and is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        dist.barrier()
+
+    model = model.to('cuda').eval()
+    world_size = get_world_size()
+    output = {}
+    for k in output_module_dict.keys():
+        output[k] = []
+    output["activations"] = []
+
+    # Iterate through the dataset to compute the activation.
+    for it, data in enumerate(data_loader):
+        data = to_cuda(data)
+        # Preprocess the data.
+        if preprocess is not None:
+            data = preprocess(data)
+        # Load real data if the generator is not specified.
+        if generator is None:
+            images = data[key_real]
+        else:
+            # Compute the generated image.
+            net_G_output = generator(data, **kwargs)
+            images = net_G_output[key_fake]
+        for metric_name, metric_module in output_module_dict.items():
+            if metric_module is not None:
+                if metric_name == 'LPIPS':
+                    assert generator is not None
+                    net_G_output_another = generator(data, **kwargs)
+                    images_another = net_G_output_another[key_fake]
+                    output[metric_name].append(metric_module(images, images_another))
+                else:
+                    output[metric_name].append(metric_module(data, images, align_corners=align_corners))
+        # Clamp the image for models that do not set the output to between
+        # -1, 1. For models that employ tanh, this has no effect.
+        images.clamp_(-1, 1)
+        y = model(images, align_corners=align_corners)
+        output["activations"].append(y)
+        if sample_size is not None and data_loader.batch_size * world_size * (it + 1) >= sample_size:
+            # Reach the number of samples we need.
+            break
+
+    for k, v in output.items():
+        if len(v) > 0:
+            output[k] = torch.cat(dist_all_gather_tensor(torch.cat(v)))[:sample_size]
+        else:
+            output[k] = None
+    return output
+
+
+@torch.no_grad()
+def get_video_activations(data_loader, key_real, key_fake, trainer=None,
+                          sample_size=None, preprocess=None, few_shot=False):
+    r"""Compute activation values and pack them in a list. We do not do all
+    reduce here.
+
+    Args:
+        data_loader (obj): PyTorch dataloader object.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        trainer (obj): Trainer. Video generation is more involved, we rely on
+            the "reset" and "test" function to conduct the evaluation.
+        sample_size (tuple or None): ``(num_videos, num_frames_per_video)`` pair
+            controlling how many videos and frames per video are used. If ``None``,
+            a small default (10 videos, 5 frames each) is used.
+        preprocess (func): The preprocess function to be applied to the data.
+        few_shot (bool): If ``True``, uses the few-shot setting.
+    Returns:
+        batch_y (tensor): Inception features of the current batch. Note that
+            only the master gpu will get it.
+    """
+    inception = inception_init()
+    batch_y = []
+
+    # We divide video sequences to different GPUs for testing.
+    num_sequences = data_loader.dataset.num_inference_sequences()
+    if sample_size is None:
+        num_videos_to_test = 10
+        num_frames_per_video = 5
+    else:
+        num_videos_to_test, num_frames_per_video = sample_size
+    if num_videos_to_test == -1:
+        num_videos_to_test = num_sequences
+    else:
+        num_videos_to_test = min(num_videos_to_test, num_sequences)
+    master_only_print('Number of videos used for evaluation: {}'.format(num_videos_to_test))
+    master_only_print('Number of frames per video used for evaluation: {}'.format(num_frames_per_video))
+
+    world_size = get_world_size()
+    if num_videos_to_test < world_size:
+        seq_to_run = [get_rank() % num_videos_to_test]
+    else:
+        num_videos_to_test = num_videos_to_test // world_size * world_size
+        seq_to_run = range(get_rank(), num_videos_to_test, world_size)
+
+    for sequence_idx in seq_to_run:
+        data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx)
+        if trainer is not None:
+            trainer.reset()
+        for it, data in enumerate(data_loader):
+            if few_shot and it == 0:
+                continue
+            if it >= num_frames_per_video:
+                break
+
+            # Preprocess the data if a preprocess function is provided.
+            if trainer is not None:
+                data = trainer.pre_process(data)
+            elif preprocess is not None:
+                data = preprocess(data)
+            data = to_cuda(data)
+
+            if trainer is None:
+                images = data[key_real][:, -1]
+            else:
+                net_G_output = trainer.test_single(data)
+                images = net_G_output[key_fake]
+            y = inception_forward(inception, images)
+            batch_y += [y]
+
+    batch_y = torch.cat(batch_y)
+    batch_y = dist_all_gather_tensor(batch_y)
+    if is_local_master():
+        batch_y = torch.cat(batch_y)
+    return batch_y
+
+
+def inception_init():
+    inception = inception_v3(pretrained=True, transform_input=False)
+    inception = inception.to('cuda')
+    inception.eval()
+    inception.fc = torch.nn.Sequential()
+    return inception
+
+
+def inception_forward(inception, images):
+    images.clamp_(-1, 1)
+    images = apply_imagenet_normalization(images)
+    images = F.interpolate(images, size=(299, 299),
+                           mode='bicubic', align_corners=True)
+    return inception(images)
+
+
+def gather_tensors(batch_y):
+    batch_y = torch.cat(batch_y)
+    batch_y = dist_all_gather_tensor(batch_y)
+    if is_local_master():
+        batch_y = torch.cat(batch_y)
+    return batch_y
+
+
+def set_sequence_idx(few_shot, data_loader, sequence_idx):
+    r"""Get sequence index
+
+    Args:
+        few_shot (bool): If ``True``, uses the few-shot setting.
+        data_loader: dataloader object
+        sequence_idx (int): which sequence to use.
+    """
+    if few_shot:
+        data_loader.dataset.set_inference_sequence_idx(sequence_idx,
+                                                       sequence_idx,
+                                                       0)
+    else:
+        data_loader.dataset.set_inference_sequence_idx(sequence_idx)
+    return data_loader
+
+
+def load_or_compute_activations(act_path, data_loader, key_real, key_fake,
+                                generator=None, sample_size=None,
+                                preprocess=None,
+                                is_video=False, few_shot_video=False,
+                                **kwargs):
+    r"""Load mean and covariance from saved npy file if exists. Otherwise,
+    compute the mean and covariance.
+
+    Args:
+        act_path (str or None): Location for the numpy file to store or to load
+            the activations.
+        data_loader (obj): PyTorch dataloader object.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        generator (obj): PyTorch trainer network.
+        sample_size (int): How many samples to be used for computing the
+            activations.
+        preprocess (func): The preprocess function to be applied to the data.
+        is_video (bool): Whether we are handling video sequences.
+        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+    Returns:
+        (torch.Tensor) Feature activations.
+    """
+    if act_path is not None and os.path.exists(act_path):
+        # Loading precomputed activations.
+        print('Load activations from {}'.format(act_path))
+        act = torch.load(act_path, map_location='cpu').cuda()
+    else:
+        # Compute activations.
+        if is_video:
+            act = get_video_activations(
+                data_loader, key_real, key_fake, generator,
+                sample_size, preprocess, few_shot_video, **kwargs
+            )
+        else:
+            act = get_activations(
+                data_loader, key_real, key_fake, generator,
+                sample_size, preprocess, **kwargs
+            )
+        if act_path is not None and is_local_master():
+            print('Save activations to {}'.format(act_path))
+            if not os.path.exists(os.path.dirname(act_path)):
+                os.makedirs(os.path.dirname(act_path), exist_ok=True)
+            torch.save(act, act_path)
+    return act
+
+
+def compute_pairwise_distance(data_x, data_y=None, num_splits=10):
+    r"""
+
+    Args:
+        data_x: torch.Tensor of shape [N, feature_dim].
+        data_y: torch.Tensor of shape [N, feature_dim] (defaults to ``data_x``).
+    Returns:
+        torch.Tensor of shape [N, N] with pairwise distances (on the CPU).
+    """
+    if data_y is None:
+        data_y = data_x
+    num_samples = data_x.shape[0]
+    assert data_x.shape[0] == data_y.shape[0]
+    dists = []
+    for i in range(num_splits):
+        batch_size = math.ceil(num_samples / num_splits)
+        start_idx = i * batch_size
+        end_idx = min((i + 1) * batch_size, num_samples)
+        dists.append(torch.cdist(data_x[start_idx:end_idx],
+                                 data_y).cpu())
+    dists = torch.cat(dists, dim=0)
+    return dists
+
+
+def compute_nn(input_features, k, num_splits=50):
+    num_samples = input_features.shape[0]
+    all_indices = []
+    all_values = []
+    for i in range(num_splits):
+        batch_size = math.ceil(num_samples / num_splits)
+        start_idx = i * batch_size
+        end_idx = min((i + 1) * batch_size, num_samples)
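+        # Distances from this chunk to all samples; the diagonal block gets
+        # +inf added so a sample is never selected as its own neighbour.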
+        dist = torch.cdist(input_features[start_idx:end_idx],
+                           input_features)
+        dist[:, start_idx:end_idx] += torch.diag(
+            float('inf') * torch.ones(dist.size(0), device=dist.device)
+        )
+        k_smallests, indices = torch.topk(dist, k, dim=-1, largest=False)
+        all_indices.append(indices)
+        all_values.append(k_smallests)
+    return torch.cat(all_values, dim=0), torch.cat(all_indices, dim=0)
diff --git a/imaginaire/evaluation/fid.py b/imaginaire/evaluation/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..793f4b77dd50381abffc82d03fee8bc36e746dab
--- /dev/null
+++ b/imaginaire/evaluation/fid.py
@@ -0,0 +1,143 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import numpy as np
+import torch
+from scipy import linalg
+
+from imaginaire.evaluation.common import load_or_compute_activations
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.no_grad()
+def compute_fid(fid_path, data_loader, net_G,
+                key_real='images', key_fake='fake_images',
+                sample_size=None, preprocess=None, return_act=False,
+                is_video=False, few_shot_video=False, **kwargs):
+    r"""Compute the fid score.
+
+    Args:
+        fid_path (str): Location for the numpy file to store or to load the
+            statistics.
+        data_loader (obj): PyTorch dataloader object.
+        net_G (obj): For image generation models, net_G is the generator network.
+            For video generation models, net_G is the trainer.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        sample_size (int or tuple): How many samples to be used.
+        preprocess (func): The preprocess function to be applied to the data.
+        return_act (bool): If ``True``, also returns feature activations of
+            real and fake data.
+        is_video (bool): Whether we are handling video sequences.
+        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+    Returns:
+        (float): FID value.
+    """
+    print('Computing FID.')
+    act_path = os.path.join(os.path.dirname(fid_path),
+                            'activations_real.npy')
+    # Get the fake mean and covariance.
+    fake_act = load_or_compute_activations(
+        None, data_loader, key_real, key_fake, net_G,
+        sample_size, preprocess, is_video=is_video,
+        few_shot_video=few_shot_video, **kwargs
+    )
+
+    # Get the ground truth mean and covariance.
+    real_act = load_or_compute_activations(
+        act_path, data_loader, key_real, key_fake, None,
+        sample_size, preprocess, is_video=is_video,
+        few_shot_video=few_shot_video, **kwargs
+    )
+
+    if is_master():
+        fid = _calculate_frechet_distance(
+            fake_act, real_act)["FID"]
+        if return_act:
+            return fid, real_act, fake_act
+        else:
+            return fid
+    elif return_act:
+        return None, None, None
+    else:
+        return None
+
+
+@torch.no_grad()
+def compute_fid_data(fid_path, data_loader_a, data_loader_b,
+                     key_a='images', key_b='images', sample_size=None,
+                     is_video=False, few_shot_video=False, **kwargs):
+    r"""Compute the fid score between two datasets.
+
+    Args:
+        fid_path (str): Location for the numpy file to store or to load the
+            statistics.
+        data_loader_a (obj): PyTorch dataloader object for dataset a.
+        data_loader_b (obj): PyTorch dataloader object for dataset b.
+        key_a (str): Dictionary key value for images in the dataset a.
+        key_b (str): Dictionary key value for images in the dataset b.
+        sample_size (int): How many samples to be used for computing the FID.
+        is_video (bool): Whether we are handling video sequences.
+        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+    Returns:
+        (float): FID value.
+    """
+    print('Computing FID.')
+    path_a = os.path.join(os.path.dirname(fid_path),
+                          'activations_a.npy')
+    min_data_size = min(len(data_loader_a.dataset),
+                        len(data_loader_b.dataset))
+    if sample_size is None:
+        sample_size = min_data_size
+    else:
+        sample_size = min(sample_size, min_data_size)
+
+    act_a = load_or_compute_activations(
+        path_a, data_loader_a, key_a, key_b, None,
+        sample_size=sample_size, is_video=is_video,
+        few_shot_video=few_shot_video, **kwargs
+    )
+    act_b = load_or_compute_activations(
+        None, data_loader_b, key_a, key_b, None,
+        sample_size=sample_size, is_video=is_video,
+        few_shot_video=few_shot_video, **kwargs
+    )
+
+    if is_master():
+        return _calculate_frechet_distance(act_a, act_b)["FID"]
+
+
+def _calculate_frechet_distance(act_1, act_2, eps=1e-6):
+    mu1 = np.mean(act_1.cpu().numpy(), axis=0)
+    sigma1 = np.cov(act_1.cpu().numpy(), rowvar=False)
+    mu2 = np.mean(act_2.cpu().numpy(), axis=0)
+    sigma2 = np.cov(act_2.cpu().numpy(), rowvar=False)
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
+    diff = mu1 - mu2
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = ('fid calculation produces singular product; '
+               'adding %s to diagonal of cov estimates') % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            print('Imaginary component {}'.format(m))
+            # raise ValueError('Imaginary component {}'.format(m))
+        covmean = covmean.real
+    tr_covmean = np.trace(covmean)
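+    # FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))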
+    return {"FID": (diff.dot(diff) + np.trace(sigma1) + np.trace(
+        sigma2) - 2 * tr_covmean)}
diff --git a/imaginaire/evaluation/kid.py b/imaginaire/evaluation/kid.py
new file mode 100644
index 0000000000000000000000000000000000000000..675b93015b03a8c1f4557e5687de0887b1f5a0a4
--- /dev/null
+++ b/imaginaire/evaluation/kid.py
@@ -0,0 +1,317 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+"""
+Modified from https://github.com/abdulfatir/gan-metrics-pytorch
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import warnings
+
+import numpy as np
+import torch
+
+from imaginaire.evaluation.common import get_activations, \
+    load_or_compute_activations
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.no_grad()
+def compute_kid(kid_path, data_loader, net_G,
+                key_real='images', key_fake='fake_images',
+                real_act=None, fake_act=None,
+                sample_size=None, preprocess=None, is_video=False,
+                save_act=True, num_subsets=1, subset_size=None, **kwargs):
+    r"""Compute the kid score.
+
+    Args:
+        kid_path (str): Location for storing feature activations.
+        data_loader (obj): PyTorch dataloader object.
+        net_G (obj): For image generation models, net_G is the generator
+            network. For video generation models, net_G is the trainer
+            because video generation requires more involved processing.
+        key_real (str): Dictionary key value for the real data.
+        key_fake (str): Dictionary key value for the fake data.
+        real_act (torch.Tensor or None): Feature activations of real data.
+        fake_act (torch.Tensor or None): Feature activations of fake data.
+        sample_size (int): How many samples to be used for computing feature
+            activations.
+        preprocess (func): The preprocess function to be applied to the data.
+        is_video (bool): Whether we are handling video sequences.
+        save_act (bool): If ``True``, saves real activations to the disk and
+            reload them in the future. It might save some computation but will
+            cost storage.
+        num_subsets (int): Number of subsets to sample from all the samples.
+        subset_size (int): Number of samples in each subset.
+    Returns:
+        kid (float): KID value.
+    """
+    print('Computing KID.')
+    act_path = os.path.join(
+        os.path.dirname(kid_path), 'activations_real.npy'
+    ) if save_act else None
+
+    # Get the fake activations.
+    if fake_act is None:
+        fake_act = load_or_compute_activations(None,
+                                               data_loader,
+                                               key_real, key_fake, net_G,
+                                               sample_size, preprocess,
+                                               is_video=is_video, **kwargs)
+    else:
+        print(f"Using precomputed activations of size {fake_act.shape}.")
+
+    # Get the ground truth activations.
+    if real_act is None:
+        real_act = load_or_compute_activations(act_path,
+                                               data_loader,
+                                               key_real, key_fake, None,
+                                               sample_size, preprocess,
+                                               is_video=is_video, **kwargs)
+    else:
+        print(f"Using precomputed activations of size {real_act.shape}.")
+
+    if is_master():
+        return _polynomial_mmd_averages(fake_act, real_act,
+                                        num_subsets,
+                                        subset_size,
+                                        ret_var=True)["KID"]
+
+
+@torch.no_grad()
+def compute_kid_data(kid_path, data_loader_a, data_loader_b,
+                     key_a='images', key_b='images', sample_size=None,
+                     is_video=False, num_subsets=1, subset_size=None,
+                     **kwargs):
+    r"""Compute the kid score between two datasets.
+
+    Args:
+        kid_path (str): Location for storing feature activations.
+        data_loader_a (obj): PyTorch dataloader object for dataset a.
+        data_loader_b (obj): PyTorch dataloader object for dataset b.
+        key_a (str): Dictionary key value for images in the dataset a.
+        key_b (str): Dictionary key value for images in the dataset b.
+        sample_size (int): How many samples to be used for computing the KID.
+        is_video (bool): Whether we are handling video sequences.
+        num_subsets (int): Number of subsets to sample from the whole data.
+        subset_size (int): Number of samples in each subset.
+    Returns:
+        kid (float): KID value.
+    """
+    min_data_size = min(len(data_loader_a.dataset),
+                        len(data_loader_b.dataset))
+    if sample_size is None:
+        sample_size = min_data_size
+    else:
+        sample_size = min(sample_size, min_data_size)
+    print('Computing KID using {} images from both distributions.'.
+          format(sample_size))
+    path_a = os.path.join(os.path.dirname(kid_path),
+                          'activations_a.npy')
+    act_a = load_or_compute_activations(path_a, data_loader_a,
+                                        key_a, key_a,
+                                        sample_size=sample_size,
+                                        is_video=is_video, **kwargs)
+    act_b = get_activations(data_loader_b, key_b, key_b,
+                            None, sample_size, None, **kwargs)
+
+    if is_master():
+        return _polynomial_mmd_averages(act_a, act_b,
+                                        num_subsets,
+                                        subset_size,
+                                        ret_var=True)["KID"]
+
+
+def _polynomial_mmd_averages(codes_g, codes_r, n_subsets, subset_size,
+                             ret_var=True, **kernel_args):
+    r"""Computes MMD between two sets of features using polynomial kernels. It
+    performs a number of repetitions of subset sampling without replacement.
+
+    Args:
+        codes_g (Tensor): Feature activations of generated images.
+        codes_r (Tensor): Feature activations of real images.
+        n_subsets (int): The number of subsets.
+        subset_size (int): The number of samples in each subset.
+        ret_var (bool): If ``True``, returns both mean and variance of MMDs,
+            otherwise only returns the mean.
+    Returns:
+        (dict): Dictionary with key ``'KID'`` containing the mean of the MMD
+            estimates over the sampled subsets.
+    """
+    mmds = np.zeros(n_subsets)
+    if ret_var:
+        mmd_vars = np.zeros(n_subsets)
+    choice = np.random.choice
+
+    if subset_size is None:
+        subset_size = min(len(codes_g), len(codes_r))
+        print("Subset size not provided, "
+              "setting it to the data size ({}).".format(subset_size))
+    if subset_size > len(codes_g) or subset_size > len(codes_r):
+        subset_size = min(len(codes_g), len(codes_r))
+        warnings.warn(
+            "Subset size is larger than the actual data size, "
+            "setting it to the data size ({}).".format(subset_size))
+
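+    # Average the MMD^2 estimates over n_subsets random subsets sampled
+    # without replacement from each feature set.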
+    for i in range(n_subsets):
+        g = codes_g[choice(len(codes_g), subset_size, replace=False)]
+        r = codes_r[choice(len(codes_r), subset_size, replace=False)]
+        o = _polynomial_mmd(g, r, **kernel_args, ret_var=ret_var)
+        if ret_var:
+            # noinspection PyUnboundLocalVariable
+            mmds[i], mmd_vars[i] = o
+        else:
+            mmds[i] = o
+    return {'KID': mmds.mean()}
+
+
+def _polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1.):
+    r"""Compute the polynomial kernel between X and Y"""
+    if gamma is None:
+        gamma = 1.0 / X.shape[1]
+    if Y is None:
+        Y = X
+
+    # K = safe_sparse_dot(X, Y.T, dense_output=True)
+    K = torch.matmul(X, Y.t())
+    K *= gamma
+    K += coef0
+    K = K ** degree
+    return K
+
+
+def _polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
+                    ret_var=True):
+    r"""Computes MMD between two sets of features using polynomial kernels. It
+    performs a number of repetitions of subset sampling without replacement.
+
+    Args:
+        codes_g (torch.Tensor): Feature activations of generated images.
+        codes_r (torch.Tensor): Feature activations of real images.
+        degree (int): The degree of the polynomial kernel.
+        gamma (float or None): Scale of the polynomial kernel.
+        coef0 (float or None): Bias of the polynomial kernel.
+        ret_var (bool): If ``True``, returns both mean and variance of MMDs,
+            otherwise only returns the mean.
+    Returns:
+        (tuple):
+          - mmd2 (torch.Tensor): The MMD^2 estimate.
+          - var_est (torch.Tensor): Its variance estimate (returned only when
+            ``ret_var`` is ``True``).
+    """
+    # use  k(x, y) = (gamma <x, y> + coef0)^degree
+    # default gamma is 1 / dim
+    X = codes_g
+    Y = codes_r
+
+    # with warnings.catch_warnings():
+    #     warnings.simplefilter('ignore')
+    K_XX = _polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
+    K_YY = _polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
+    K_XY = _polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
+
+    return _mmd2_and_variance(K_XX, K_XY, K_YY, ret_var=ret_var)
+
+
+def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
+                       mmd_est='unbiased', ret_var=True):
+    r"""Based on
+    https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
+    but changed to not compute the full kernel matrix at once
+    """
+
+    m = K_XX.shape[0]
+    assert K_XX.shape == (m, m)
+    assert K_XY.shape == (m, m)
+    assert K_YY.shape == (m, m)
+    var_at_m = m
+
+    # Get the various sums of kernels that we'll use
+    # Kts drop the diagonal, but we don't need to compute them explicitly
+    if unit_diagonal:
+        diag_X = diag_Y = 1
+        sum_diag_X = sum_diag_Y = m
+        sum_diag2_X = sum_diag2_Y = m
+    else:
+        diag_X = torch.diagonal(K_XX)
+        diag_Y = torch.diagonal(K_YY)
+
+        sum_diag_X = diag_X.sum()
+        sum_diag_Y = diag_Y.sum()
+
+        sum_diag2_X = _sqn(diag_X)
+        sum_diag2_Y = _sqn(diag_Y)
+
+    Kt_XX_sums = K_XX.sum(dim=1) - diag_X
+    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y
+    K_XY_sums_0 = K_XY.sum(dim=0)
+    K_XY_sums_1 = K_XY.sum(dim=1)
+
+    Kt_XX_sum = Kt_XX_sums.sum()
+    Kt_YY_sum = Kt_YY_sums.sum()
+    K_XY_sum = K_XY_sums_0.sum()
+
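+    # 'unbiased' estimate:
+    #   MMD^2 = 1/(m(m-1)) * sum_{i != j} k(x_i, x_j)
+    #         + 1/(m(m-1)) * sum_{i != j} k(y_i, y_j)
+    #         - 2/m^2      * sum_{i, j}   k(x_i, y_j).
+    # 'biased' keeps the diagonal terms; 'u-statistic' also removes the
+    # diagonal of the cross term.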
+    if mmd_est == 'biased':
+        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
+                + (Kt_YY_sum + sum_diag_Y) / (m * m)
+                - 2 * K_XY_sum / (m * m))
+    else:
+        assert mmd_est in {'unbiased', 'u-statistic'}
+        mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
+        if mmd_est == 'unbiased':
+            mmd2 -= 2 * K_XY_sum / (m * m)
+        else:
+            mmd2 -= 2 * (K_XY_sum - torch.trace(K_XY)) / (m * (m - 1))
+
+    if not ret_var:
+        return mmd2
+
+    Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
+    Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
+    K_XY_2_sum = _sqn(K_XY)
+
+    dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
+    dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)
+
+    m1 = m - 1
+    m2 = m - 2
+    zeta1_est = (
+        1 / (m * m1 * m2) *
+        (_sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
+        - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
+        + 1 / (m * m * m1) * (
+            _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
+        - 2 / m ** 4 * K_XY_sum ** 2
+        - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+        + 2 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+    )
+    zeta2_est = (
+        1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
+        - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
+        + 2 / (m * m) * K_XY_2_sum
+        - 2 / m ** 4 * K_XY_sum ** 2
+        - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+        + 4 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+    )
+    var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
+               + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)
+
+    return mmd2.cpu().numpy(), var_est.cpu().numpy()
+
+
+def _sqn(arr):
+    r"""Squared norm."""
+    flat = arr.view(-1)
+    return flat.dot(flat)
diff --git a/imaginaire/evaluation/knn.py b/imaginaire/evaluation/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b1cf502dc01e89e67f021572df9e94c71938ef
--- /dev/null
+++ b/imaginaire/evaluation/knn.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+from imaginaire.evaluation.common import compute_nn
+
+
+def _get_1nn_acc(data_x, data_y, k=1):
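+    # Leave-one-out 1-NN two-sample test: label data_x samples as 1 and data_y
+    # samples as 0, classify each sample by majority vote among its k nearest
+    # neighbours (excluding itself), and report classification accuracies.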
+    device = data_x.device
+    n0 = data_x.size(0)
+    n1 = data_y.size(0)
+    data_all = torch.cat((data_x, data_y), dim=0)
+    val, idx = compute_nn(data_all, k)
+    label = torch.cat((torch.ones(n0, device=device),
+                       torch.zeros(n1, device=device)))
+
+    count = torch.zeros(n0 + n1, device=device)
+    for i in range(0, k):
+        count = count + label.index_select(0, idx[:, i])
+    pred = torch.ge(count, (float(k) / 2) *
+                    torch.ones(n0 + n1, device=device)).float()
+
+    tp = (pred * label).sum()
+    fp = (pred * (1 - label)).sum()
+    fn = ((1 - pred) * label).sum()
+    tn = ((1 - pred) * (1 - label)).sum()
+    acc_r = (tp / (tp + fn)).item()
+    acc_f = (tn / (tn + fp)).item()
+    acc = torch.eq(label, pred).float().mean().item()
+
+    return {'1NN_acc': acc,
+            '1NN_acc_real': acc_r,
+            '1NN_acc_fake': acc_f}
diff --git a/imaginaire/evaluation/lpips.py b/imaginaire/evaluation/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b10d77ad739151347ee2c45e6417145206cdbc
--- /dev/null
+++ b/imaginaire/evaluation/lpips.py
@@ -0,0 +1,153 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from collections import namedtuple
+
+import torch
+from torch import nn, distributed as dist
+import torchvision.models as tv
+from torch.distributed import barrier
+
+from imaginaire.utils.distributed import is_local_master
+
+
+def get_lpips_model():
+    if dist.is_initialized() and not is_local_master():
+        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+        barrier()
+
+    model = LPIPSNet().cuda()
+
+    if dist.is_initialized() and is_local_master():
+        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+        barrier()
+    return model
+
+
+# Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity
+
+def normalize_tensor(in_feat, eps=1e-5):
+    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps)
+    return in_feat / (norm_factor + eps)
+
+
+class NetLinLayer(nn.Module):
+    """ A single linear layer used as placeholder for LPIPS learnt weights """
+
+    def __init__(self, dim):
+        super(NetLinLayer, self).__init__()
+        self.weight = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+    def forward(self, inp):
+        out = self.weight * inp
+        return out
+
+
+class ScalingLayer(nn.Module):
+    # For rescaling the input to vgg16
+    def __init__(self):
+        super(ScalingLayer, self).__init__()
+        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def forward(self, inp):
+        return (inp - self.shift) / self.scale
+
+
+class LPIPSNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = LPNet()
+
+    @torch.no_grad()
+    def forward(self, fake_images, fake_images_another, align_corners=True):
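+        # LPIPS-style distance: sum over VGG layers of channel-weighted,
+        # squared differences of unit-normalized features, averaged over
+        # spatial locations.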
+        features, shape = self._forward_single(fake_images)
+        features_another, _ = self._forward_single(fake_images_another)
+        result = 0
+        for i, g_feat in enumerate(features):
+            cur_diff = torch.sum((g_feat - features_another[i]) ** 2, dim=1) / (shape[i] ** 2)
+            result += cur_diff
+        return result
+
+    def _forward_single(self, images):
+        return self.model(torch.clamp(images, 0, 1))
+
+
+class LPNet(nn.Module):
+    def __init__(self):
+        super(LPNet, self).__init__()
+
+        self.scaling_layer = ScalingLayer()
+        self.net = vgg16(pretrained=True, requires_grad=False)
+        self.L = 5
+        dims = [64, 128, 256, 512, 512]
+        self.lins = nn.ModuleList([NetLinLayer(dims[i]) for i in range(self.L)])
+
+        weights = torch.hub.load_state_dict_from_url(
+            'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth'
+        )
+        for i in range(self.L):
+            self.lins[i].weight.data = torch.sqrt(weights["lin%d.model.1.weight" % i])
+
+    def forward(self, in0, avg=False):
+        in0 = 2 * in0 - 1
+        in0_input = self.scaling_layer(in0)
+        outs0 = self.net.forward(in0_input)
+        feats0 = {}
+        shapes = []
+        res = []
+
+        for kk in range(self.L):
+            feats0[kk] = normalize_tensor(outs0[kk])
+
+        if avg:
+            res = [self.lins[kk](feats0[kk]).mean([2, 3], keepdim=False) for kk in range(self.L)]
+        else:
+            for kk in range(self.L):
+                cur_res = self.lins[kk](feats0[kk])
+                shapes.append(cur_res.shape[-1])
+                res.append(cur_res.reshape(cur_res.shape[0], -1))
+
+        return res, shapes
+
+
+class vgg16(torch.nn.Module):
+    def __init__(self, requires_grad=False, pretrained=True):
+        super(vgg16, self).__init__()
+        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        self.N_slices = 5
+        for x in range(4):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(4, 9):
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(9, 16):
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(16, 23):
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(23, 30):
+            self.slice5.add_module(str(x), vgg_pretrained_features[x])
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        h = self.slice1(x)
+        h_relu1_2 = h
+        h = self.slice2(h)
+        h_relu2_2 = h
+        h = self.slice3(h)
+        h_relu3_3 = h
+        h = self.slice4(h)
+        h_relu4_3 = h
+        h = self.slice5(h)
+        h_relu5_3 = h
+        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+
+        return out
diff --git a/imaginaire/evaluation/msid.py b/imaginaire/evaluation/msid.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59a39dd38fe29d8d463de8348eb3749189c6b2b
--- /dev/null
+++ b/imaginaire/evaluation/msid.py
@@ -0,0 +1,375 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+# flake8: noqa
+from scipy.sparse import lil_matrix, diags, eye
+import math
+
+import numpy as np
+import torch
+
+EPSILON = 1e-6
+NORMALIZATION = 1e6
+
+
+def _get_msid(x, y, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100,
+              rademacher=False, graph_builder='full',
+              msid_mode='max', normalized_laplacian=True, normalize='empty'):
+    """
+    Compute the msid score between two samples, x and y
+    Arguments:
+        x: x samples
+        y: y samples
+        ts: temperature values
+        k: number of neighbours for graph construction
+        m: Lanczos steps in SLQ
+        niters: number of starting random vectors for SLQ
+        rademacher: if True, sample random vectors from Rademacher
+            distributions, else sample standard normal distribution
+        graph_builder: if 'kgraph', uses faster graph construction (options:
+            'sparse', 'kgraph')
+        msid_mode: 'l2' to compute the l2 norm of the distance between `msid1`
+            and `msid2`; 'max' to find the maximum absolute difference between
+            the two descriptors over temperatures
+        normalized_laplacian: if True, use normalized Laplacian
+        normalize: 'empty' for average heat kernel (corresponds to the empty
+            graph normalization of NetLSD), 'complete' for the complete, 'er'
+            for erdos-renyi normalization, 'none' for no normalization
+    Returns:
+        msid_score: the scalar value of the distance between descriptors
+    """
+    normed_msidx = msid_descriptor(x, ts, k, m, niters, rademacher,
+                                   graph_builder, normalized_laplacian,
+                                   normalize)
+    normed_msidy = msid_descriptor(y, ts, k, m, niters, rademacher,
+                                   graph_builder, normalized_laplacian,
+                                   normalize)
+
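+    # Weighting over temperatures used by the 'max' mode; it down-weights
+    # very small and very large temperatures.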
+    c = np.exp(-2 * (ts + 1 / ts))
+
+    if msid_mode == 'l2':
+        score = np.linalg.norm(normed_msidx - normed_msidy)
+    elif msid_mode == 'max':
+        score = np.amax(c * np.abs(normed_msidx - normed_msidy))
+    else:
+        raise Exception('Use either l2 or max mode.')
+
+    return {'IMD': score}
+
+
+def msid_descriptor(x, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100,
+                    rademacher=False, graph_builder='full',
+                    normalized_laplacian=True, normalize='empty'):
+    """
+    Compute the msid descriptor for a single sample x
+    Arguments:
+        x: x samples
+        ts: temperature values
+        k: number of neighbours for graph construction
+        m: Lanczos steps in SLQ
+        niters: number of starting random vectors for SLQ
+        rademacher: if True, sample random vectors from Rademacher
+            distributions, else sample standard normal distribution
+        graph_builder: if 'kgraph', uses faster graph construction (options:
+            'sparse', 'kgraph')
+        normalized_laplacian: if True, use normalized Laplacian
+        normalize: 'empty' for average heat kernel (corresponds to the empty
+            graph normalization of NetLSD), 'complete' for the complete, 'er'
+            for erdos-renyi normalization, 'none' for no normalization
+    Returns:
+        normed_msidx: normalized msid descriptor
+    """
+    Lx = _build_graph(x, k, graph_builder, normalized_laplacian)
+
+    nx = Lx.shape[0]
+    msidx = slq_red_var(Lx, m, niters, ts, rademacher)
+
+    normed_msidx = _normalize_msid(msidx, normalize, nx, k, ts) * NORMALIZATION
+
+    return normed_msidx
+
+
+def _build_graph(data, k=5, graph_builder='full', normalized=True):
+    """
+    Return Laplacian from data or load preconstructed from path
+
+    Arguments:
+        data: samples
+        k: number of neighbours for graph construction
+        graph_builder: if 'kgraph', use faster graph construction
+        normalized: if True, use normalized Laplacian
+    Returns:
+        L: Laplacian of the graph constructed with data
+    """
+    if graph_builder == 'sparse':
+        A = construct_graph_sparse(data.cpu().numpy(), k)
+    elif graph_builder == 'kgraph':
+        A = construct_graph_kgraph(data.cpu().numpy(), k)
+    elif graph_builder == 'full':
+        A = lil_matrix(construct_graph(data, k).cpu().numpy()).tocsr()
+    else:
+        raise Exception('Please specify graph builder: sparse or kgraph.')
+    A = (A + A.T) / 2
+    A.data = np.ones(A.data.shape)
+    L = _laplacian_sparse(A, normalized)
+    return L
+
+
+def _normalize_msid(msid, normalization, n, k, ts):
+    normed_msid = msid.copy()
+    if normalization == 'empty':
+        normed_msid /= n
+    elif normalization == 'complete':
+        normed_msid /= (1 + (n - 1) * np.exp(-(1 + 1 / (n - 1)) * ts))
+    elif normalization == 'er':
+        xs = np.linspace(0, 1, n)
+        er_spectrum = 4 / np.sqrt(k) * xs + 1 - 2 / np.sqrt(k)
+        er_msid = np.exp(-np.outer(ts, er_spectrum)).sum(-1)
+        normed_msid = normed_msid / (er_msid + EPSILON)
+    elif normalization == 'none' or normalization is None:
+        pass
+    else:
+        raise ValueError('Unknown normalization parameter!')
+    return normed_msid
+
+
+def _lanczos_m(A, m, nv, rademacher, SV=None):
+    """
+    Lanczos algorithm computes symmetric m x m tridiagonal matrix T and
+    matrix V with orthogonal rows constituting the basis of the Krylov
+    subspace K_m(A, x), where x is an arbitrary starting unit vector. This
+    implementation parallelizes `nv` starting vectors.
+
+    Arguments:
+        m: number of Lanczos steps
+        nv: number of random vectors
+        rademacher: True to use Rademacher distribution,
+                    False - standard normal for random vectors
+        SV: specified starting vectors
+
+    Returns:
+        T: a nv x m x m tensor, where T[i, :, :] is the i-th symmetric
+            tridiagonal matrix
+        V: a n x m x nv tensor, where V[:, :, i] is the i-th matrix with
+            orthogonal rows
+    """
+    orthtol = 1e-5
+    if not isinstance(SV, np.ndarray):
+        if rademacher:
+            SV = np.sign(np.random.randn(A.shape[0], nv))
+        else:
+            SV = np.random.randn(A.shape[0],
+                                 nv)  # init random vectors in columns: n x nv
+    V = np.zeros((SV.shape[0], m, nv))
+    T = np.zeros((nv, m, m))
+
+    np.divide(SV, np.linalg.norm(SV, axis=0), out=SV)  # normalize each column
+    V[:, 0, :] = SV
+
+    w = A.dot(SV)
+    alpha = np.einsum('ij,ij->j', w, SV)
+    w -= alpha[None, :] * SV
+    beta = np.einsum('ij,ij->j', w, w)
+    np.sqrt(beta, beta)
+
+    T[:, 0, 0] = alpha
+    T[:, 0, 1] = beta
+    T[:, 1, 0] = beta
+
+    np.divide(w, beta[None, :], out=w)
+    V[:, 1, :] = w
+    t = np.zeros((m, nv))
+
+    for i in range(1, m):
+        SVold = V[:, i - 1, :]
+        SV = V[:, i, :]
+
+        w = A.dot(SV)  # sparse @ dense
+        w -= beta[None, :] * SVold  # n x nv
+        np.einsum('ij,ij->j', w, SV, out=alpha)
+
+        T[:, i, i] = alpha
+
+        if i < m - 1:
+            w -= alpha[None, :] * SV  # n x nv
+            # reortho
+            np.einsum('ijk,ik->jk', V, w, out=t)
+            w -= np.einsum('ijk,jk->ik', V, t)
+            np.einsum('ij,ij->j', w, w, out=beta)
+            np.sqrt(beta, beta)
+            np.divide(w, beta[None, :], out=w)
+
+            T[:, i, i + 1] = beta
+            T[:, i + 1, i] = beta
+
+            # more reotho
+            innerprod = np.einsum('ijk,ik->jk', V, w)
+            reortho = False
+            for _ in range(100):
+                if not (innerprod > orthtol).sum():
+                    reortho = True
+                    break
+                np.einsum('ijk,ik->jk', V, w, out=t)
+                w -= np.einsum('ijk,jk->ik', V, t)
+                np.divide(w, np.linalg.norm(w, axis=0)[None, :], out=w)
+                innerprod = np.einsum('ijk,ik->jk', V, w)
+
+            V[:, i + 1, :] = w
+
+            if (np.abs(beta) > 1e-6).sum() == 0 or not reortho:
+                break
+    return T, V
+
+
+def _slq(A, m, niters, rademacher):
+    """
+    Compute the trace of matrix exponential
+
+    Arguments:
+        A: square matrix in trace(exp(A))
+        m: number of Lanczos steps
+        niters: number of quadratures (also, the number of random vectors in the
+            hutchinson trace estimator)
+        rademacher: True to use Rademacher distribution, False - standard normal
+            for random vectors in Hutchinson
+    Returns: trace: estimate of trace of matrix exponential
+    """
+    T, _ = _lanczos_m(A, m, niters, rademacher)
+    eigvals, eigvecs = np.linalg.eigh(T)
+    expeig = np.exp(eigvals)
+    sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+    trace = A.shape[-1] * (expeig * sqeigv1).sum() / niters
+    return trace
+
+
+def _slq_ts(A, m, niters, ts, rademacher):
+    """
+    Compute the trace of matrix exponential
+
+    Arguments:
+        A: square matrix in trace(exp(-t*A)), where t is temperature
+        m: number of Lanczos steps
+        niters: number of quadratures (also, the number of random vectors in the
+            hutchinson trace estimator)
+        ts: an array with temperatures
+        rademacher: True to use Rademacher distribution, False - standard normal
+            for random vectors in Hutchinson
+    Returns:
+        trace: estimate of trace of matrix exponential across temperatures `ts`
+    """
+    T, _ = _lanczos_m(A, m, niters, rademacher)
+    eigvals, eigvecs = np.linalg.eigh(T)
+    expeig = np.exp(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m)
+    sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+    traces = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1)
+    return traces
+
+
+def _slq_ts_fs(A, m, niters, ts, rademacher, fs):
+    """
+    Compute the trace of matrix functions
+
+    Arguments:
+        A: square matrix in trace(exp(-t*A)), where t is temperature
+        m: number of Lanczos steps
+        niters: number of quadratures (also, the number of random vectors in the
+            hutchinson trace estimator)
+        ts: an array with temperatures
+        rademacher: True to use Rademacher distribution, else - standard normal
+            for random vectors in Hutchinson
+        fs: a list of functions
+    Returns:
+        traces: estimate of traces for each of the functions in fs
+    """
+    T, _ = _lanczos_m(A, m, niters, rademacher)
+    eigvals, eigvecs = np.linalg.eigh(T)
+    traces = np.zeros((len(fs), len(ts)))
+    for i, f in enumerate(fs):
+        expeig = f(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m)
+        sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+        traces[i, :] = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1)
+    return traces
+
+
+def slq_red_var(A, m, niters, ts, rademacher):
+    """
+    Compute the trace of matrix exponential with reduced variance
+
+    Arguments:
+        A: square matrix in trace(exp(-t*A)), where t is temperature
+        m: number of Lanczos steps
+        niters: number of quadratures (also, the number of random vectors in the
+            hutchinson trace estimator)
+        ts: an array with temperatures
+    Returns:
+        traces: estimate of trace for each temperature value in `ts`
+    """
+    fs = [np.exp, lambda x: x]
+
+    traces = _slq_ts_fs(A, m, niters, ts, rademacher, fs)
+    subee = traces[0, :] - traces[1, :] / np.exp(ts)
+    sub = - ts * A.shape[0] / np.exp(ts)
+    return subee + sub
+
+
+def np_euc_cdist(data):
+    dd = np.sum(data * data, axis=1)
+    dist = -2 * np.dot(data, data.T)
+    dist += dd + dd[:, np.newaxis]
+    np.fill_diagonal(dist, 0)
+    np.sqrt(dist, dist)
+    return dist
+
+
+def construct_graph_sparse(data, k):
+    n = len(data)
+    spmat = lil_matrix((n, n))
+    dd = np.sum(data * data, axis=1)
+
+    for i in range(n):
+        dists = dd - 2 * data[i, :].dot(data.T)
+        inds = np.argpartition(dists, k + 1)[:k + 1]
+        inds = inds[inds != i]
+        spmat[i, inds] = 1
+
+    return spmat.tocsr()
+
+
+def construct_graph_kgraph(data, k):
+    raise NotImplementedError
+    #
+    # n = len(data)
+    # spmat = lil_matrix((n, n))
+    # index = pykgraph.KGraph(data, 'euclidean')
+    # index.build(reverse=0, K=2 * k + 1, L=2 * k + 50)
+    # result = index.search(data, K=k + 1)[:, 1:]
+    # spmat[np.repeat(np.arange(n), k, 0), result.ravel()] = 1
+    # return spmat.tocsr()
+
+
+def construct_graph(input_features, k, num_splits=10):
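+    # Build a k-NN adjacency matrix: mark the k+1 smallest distances per row
+    # (which include the sample itself) and then remove the self-edges by
+    # subtracting the identity.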
+    num_samples = input_features.shape[0]
+    indices = []
+    for i in range(num_splits):
+        batch_size = math.ceil(num_samples / num_splits)
+        start_idx = i * batch_size
+        end_idx = min((i + 1) * batch_size, num_samples)
+        dist = torch.cdist(input_features[start_idx:end_idx],
+                           input_features)
+        indices.append(torch.topk(dist, k + 1, dim=-1, largest=False)[1].cpu())
+    indices = torch.cat(indices, dim=0)
+    graph = torch.zeros(num_samples, num_samples, device=indices.device)
+    graph.scatter_(1, indices, 1)
+    graph -= torch.eye(num_samples, device=graph.device)
+    return graph
+
+
+def _laplacian_sparse(A, normalized=True):
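+    # Normalized graph Laplacian: L = I - D^{-1/2} A D^{-1/2};
+    # unnormalized: L = D - A, where D is the diagonal degree matrix.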
+    D = A.sum(1).A1
+    if normalized:
+        Dsqrt = diags(1 / np.sqrt(D))
+        L = eye(A.shape[0]) - Dsqrt.dot(A).dot(Dsqrt)
+    else:
+        L = diags(D) - A
+    return L
diff --git a/imaginaire/evaluation/prdc.py b/imaginaire/evaluation/prdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba2abba01f6e4b8e49b4b979ebe0aa93e4f9b00
--- /dev/null
+++ b/imaginaire/evaluation/prdc.py
@@ -0,0 +1,124 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""
+Modified from https://github.com/clovaai/generative-evaluation-prdc
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+"""
+import os
+
+import torch
+
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+from .common import load_or_compute_activations, compute_pairwise_distance, \
+    compute_nn
+
+
+@torch.no_grad()
+def compute_prdc(prdc_path, data_loader, net_G,
+                 key_real='images', key_fake='fake_images',
+                 real_act=None, fake_act=None,
+                 sample_size=None, save_act=True, k=10, **kwargs):
+    r"""Compute precision diversity curve
+
+    Args:
+
+    """
+    print('Computing PRDC.')
+    act_path = os.path.join(
+        os.path.dirname(prdc_path), 'activations_real.npy'
+    ) if save_act else None
+
+    # Get the fake activations.
+    if fake_act is None:
+        fake_act = load_or_compute_activations(None,
+                                               data_loader,
+                                               key_real, key_fake, net_G,
+                                               sample_size=sample_size,
+                                               **kwargs)
+    else:
+        print(f"Using precomputed activations of size {fake_act.shape}.")
+
+    # Get the ground truth activations.
+    if real_act is None:
+        real_act = load_or_compute_activations(act_path,
+                                               data_loader,
+                                               key_real, key_fake, None,
+                                               sample_size=sample_size,
+                                               **kwargs)
+    else:
+        print(f"Using precomputed activations of size {real_act.shape}.")
+
+    if is_master():
+        prdc_data = _get_prdc(real_act, fake_act, k)
+        return \
+            prdc_data['precision'], prdc_data['recall'], \
+            prdc_data['density'], prdc_data['coverage']
+    else:
+        return None, None, None, None
+
+
+def get_kth_value(unsorted, k, dim=-1):
+    r"""
+
+    Args:
+        unsorted: torch.Tensor of any dimensionality.
+        k: int
+    Returns:
+        kth values along the designated axis.
+    """
+    indices = torch.topk(unsorted, k, dim=dim, largest=False)[1]
+    k_smallests = torch.gather(unsorted, dim=dim, index=indices)
+    kth_values = k_smallests.max(dim=dim)[0]
+    return kth_values
+
+
+def _get_prdc(real_features, fake_features, nearest_k):
+    r"""
+    Computes precision, recall, density, and coverage given two manifolds.
+
+    Args:
+        real_features: torch.Tensor of shape [N, feature_dim].
+        fake_features: torch.Tensor of shape [N, feature_dim].
+        nearest_k: int.
+    Returns:
+        dict of precision, recall, density, and coverage.
+    """
+    real_nearest_neighbour_distances, _ = compute_nn(
+        real_features, nearest_k)
+    real_nearest_neighbour_distances = \
+        real_nearest_neighbour_distances.max(dim=-1)[0].cpu()
+    fake_nearest_neighbour_distances, _ = compute_nn(
+        fake_features, nearest_k)
+    fake_nearest_neighbour_distances = \
+        fake_nearest_neighbour_distances.max(dim=-1)[0].cpu()
+    distance_real_fake = compute_pairwise_distance(
+        real_features, fake_features)
+
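+    # Following the PRDC definitions: precision is the fraction of fake
+    # samples inside at least one real sample's k-NN ball; recall is the
+    # fraction of real samples inside at least one fake sample's k-NN ball;
+    # density averages, over fake samples, the number of real k-NN balls
+    # containing them divided by k; coverage is the fraction of real samples
+    # whose nearest fake sample lies within their k-NN radius.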
+    precision = (
+            distance_real_fake <
+            torch.unsqueeze(real_nearest_neighbour_distances, dim=1)
+    ).any(dim=0).float().mean().item()
+
+    recall = (
+            distance_real_fake <
+            torch.unsqueeze(fake_nearest_neighbour_distances, dim=0)
+    ).any(dim=1).float().mean().item()
+
+    density = (1. / float(nearest_k)) * (
+            distance_real_fake <
+            torch.unsqueeze(real_nearest_neighbour_distances, dim=1)
+    ).sum(dim=0).float().mean().item()
+
+    # noinspection PyUnresolvedReferences
+    coverage = (
+            distance_real_fake.min(dim=1)[0] <
+            real_nearest_neighbour_distances
+    ).float().mean().item()
+
+    return dict(precision=precision, recall=recall,
+                density=density, coverage=coverage)
diff --git a/imaginaire/evaluation/pretrained.py b/imaginaire/evaluation/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..b99c8f19cbb7687ac5e88b28939e83475e458971
--- /dev/null
+++ b/imaginaire/evaluation/pretrained.py
@@ -0,0 +1,232 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+"""
+Modified from
+https://github.com/mseitzer/pytorch-fid
+
+Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+from torchvision.models import inception, inception_v3, vgg16
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases' \
+                  '/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+
+class SwAV(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = torch.hub.load('facebookresearch/swav', 'resnet50',
+                                    pretrained=True)
+        self.model.fc = torch.nn.Sequential()
+
+    def forward(self, x, align_corners=True):
+        y = self.model(F.interpolate(
+            x, size=(224, 224), mode='bicubic', align_corners=align_corners))
+        return y
+
+
+class Vgg16(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = vgg16(pretrained=True, init_weights=False)
+        self.model.classifier = torch.nn.Sequential(
+            *[self.model.classifier[i] for i in range(4)]
+        )
+
+    def forward(self, x, align_corners=True):
+        y = self.model(F.interpolate(
+            x, size=(224, 224), mode='bicubic', align_corners=align_corners))
+        return y
+
+
+class InceptionV3(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = inception_v3(transform_input=False,
+                                  pretrained=True,
+                                  init_weights=False)
+        self.model.fc = torch.nn.Sequential()
+
+    def forward(self, x, align_corners=True):
+        y = self.model(F.interpolate(
+            x, size=(299, 299), mode='bicubic', align_corners=align_corners))
+        return y
+
+
+class TFInceptionV3(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = inception_v3(transform_input=False,
+                                  num_classes=1008,
+                                  aux_logits=False,
+                                  pretrained=False,
+                                  init_weights=False)
+        self.model.Mixed_5b = FIDInceptionA(192, pool_features=32)
+        self.model.Mixed_5c = FIDInceptionA(256, pool_features=64)
+        self.model.Mixed_5d = FIDInceptionA(288, pool_features=64)
+        self.model.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+        self.model.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+        self.model.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+        self.model.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+        self.model.Mixed_7b = FIDInceptionE_1(1280)
+        self.model.Mixed_7c = FIDInceptionE_2(2048)
+
+        state_dict = load_state_dict_from_url(
+            FID_WEIGHTS_URL, progress=True, map_location='cpu'
+        )
+        self.model.load_state_dict(state_dict)
+        self.model.fc = torch.nn.Sequential()
+
+    def forward(self, x, align_corners=True):
+        y = self.model(F.interpolate(
+            x, size=(299, 299), mode='bicubic', align_corners=align_corners))
+        return y
+
+
+class FIDInceptionA(inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: TensorFlow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
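+
+
+# Illustration (hedged sketch) of the `count_include_pad` patch used above:
+#   x = torch.ones(1, 1, 3, 3)
+#   F.avg_pool2d(x, 3, 1, 1)                           # corners ~0.44: padded zeros are averaged in
+#   F.avg_pool2d(x, 3, 1, 1, count_include_pad=False)  # all ones: matches the TF Inception behaviour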
diff --git a/imaginaire/evaluation/segmentation/__init__.py b/imaginaire/evaluation/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a71a20ebb0ac2940c091f064de5a029126af42
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .common import get_segmentation_hist_model, get_miou, compute_hist
+
+__all__ = ['get_segmentation_hist_model', 'get_miou', 'compute_hist']
diff --git a/imaginaire/evaluation/segmentation/celebamask_hq.py b/imaginaire/evaluation/segmentation/celebamask_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab95f325eb681bb0d80bac2c626494ee9e81b51b
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/celebamask_hq.py
@@ -0,0 +1,130 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Unet(nn.Module):
+    def __init__(
+            self,
+            feature_scale=4,
+            n_classes=19,
+            is_deconv=True,
+            in_channels=3,
+            is_batchnorm=True,
+            image_size=512,
+            use_dont_care=False
+    ):
+        super(Unet, self).__init__()
+        self.is_deconv = is_deconv
+        self.in_channels = in_channels
+        self.is_batchnorm = is_batchnorm
+        self.feature_scale = feature_scale
+        self.image_size = image_size
+        self.n_classes = n_classes
+        self.use_dont_care = use_dont_care
+
+        filters = [64, 128, 256, 512, 1024]
+        filters = [int(x / self.feature_scale) for x in filters]
+
+        # downsampling
+        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
+        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
+        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
+
+        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
+
+        # upsampling
+        self.up_concat4 = unetUp(
+            filters[4], filters[3], self.is_deconv, self.is_batchnorm)
+        self.up_concat3 = unetUp(
+            filters[3], filters[2], self.is_deconv, self.is_batchnorm)
+        self.up_concat2 = unetUp(
+            filters[2], filters[1], self.is_deconv, self.is_batchnorm)
+        self.up_concat1 = unetUp(
+            filters[1], filters[0], self.is_deconv, self.is_batchnorm)
+
+        # final conv (without any concat)
+        self.final = nn.Conv2d(filters[0], n_classes, 1)
+
+    def forward(self, images, align_corners=True):
+        images = F.interpolate(
+            images, size=(self.image_size, self.image_size), mode='bicubic',
+            align_corners=align_corners
+        )
+
+        conv1 = self.conv1(images)
+        maxpool1 = self.maxpool1(conv1)
+        conv2 = self.conv2(maxpool1)
+        maxpool2 = self.maxpool2(conv2)
+        conv3 = self.conv3(maxpool2)
+        maxpool3 = self.maxpool3(conv3)
+        conv4 = self.conv4(maxpool3)
+        maxpool4 = self.maxpool4(conv4)
+        center = self.center(maxpool4)
+        up4 = self.up_concat4(conv4, center)
+        up3 = self.up_concat3(conv3, up4)
+        up2 = self.up_concat2(conv2, up3)
+        up1 = self.up_concat1(conv1, up2)
+        probs = self.final(up1)
+        pred = torch.argmax(probs, dim=1)
+        return pred
+
+
+class unetConv2(nn.Module):
+    def __init__(self, in_size, out_size, is_batchnorm):
+        super(unetConv2, self).__init__()
+
+        if is_batchnorm:
+            self.conv1 = nn.Sequential(
+                nn.Conv2d(in_size, out_size, 3, 1, 1),
+                nn.BatchNorm2d(out_size),
+                nn.ReLU(),
+            )
+            self.conv2 = nn.Sequential(
+                nn.Conv2d(out_size, out_size, 3, 1, 1),
+                nn.BatchNorm2d(out_size),
+                nn.ReLU(),
+            )
+        else:
+            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1),
+                                       nn.ReLU())
+            self.conv2 = nn.Sequential(
+                nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU()
+            )
+
+    def forward(self, inputs):
+        outputs = self.conv1(inputs)
+        outputs = self.conv2(outputs)
+        return outputs
+
+
+class unetUp(nn.Module):
+    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
+        super(unetUp, self).__init__()
+        self.conv = unetConv2(in_size, out_size, is_batchnorm)
+        if is_deconv:
+            self.up = nn.ConvTranspose2d(
+                in_size, out_size, kernel_size=2, stride=2)
+        else:
+            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
+
+    def forward(self, inputs1, inputs2):
+        outputs2 = self.up(inputs2)
+        offset = outputs2.size()[2] - inputs1.size()[2]
+        padding = 2 * [offset // 2, offset // 2]
+        outputs1 = F.pad(inputs1, padding)
+
+        return self.conv(torch.cat([outputs1, outputs2], 1))
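+
+
+if __name__ == '__main__':
+    # Hedged smoke test (sketch, not part of the evaluation pipeline): run the
+    # untrained Unet on a random image and check the predicted label map shape.
+    net = Unet()
+    dummy = torch.randn(1, 3, 512, 512)
+    labels = net(dummy)
+    print(labels.shape)  # torch.Size([1, 512, 512]); values are class ids in [0, 18]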
diff --git a/imaginaire/evaluation/segmentation/cocostuff.py b/imaginaire/evaluation/segmentation/cocostuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..601f0c28a3811f994df9a9375bd6c5fc08509d9d
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/cocostuff.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+from torch.nn import functional as F
+import torch.hub
+
+
+class DeepLabV2(nn.Module):
+    def __init__(self, n_classes=182, image_size=512, use_dont_care=True):
+        super(DeepLabV2, self).__init__()
+        self.model = torch.hub.load(
+            "kazuto1011/deeplab-pytorch", "deeplabv2_resnet101",
+            pretrained=False, n_classes=182
+        )
+        state_dict = torch.hub.load_state_dict_from_url(
+            'https://github.com/kazuto1011/deeplab-pytorch/releases/download/'
+            'v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth',
+            map_location="cpu"
+        )
+        self.model.load_state_dict(state_dict)
+
+        self.image_size = image_size
+        # self.mean = torch.tensor([122.675, 116.669, 104.008], device="cuda")
+        self.mean = torch.tensor([104.008, 116.669, 122.675], device="cuda")
+        self.n_classes = n_classes
+        self.use_dont_care = use_dont_care
+
+    def forward(self, images, align_corners=True):
+        scale = self.image_size / max(images.shape[2:])
+        images = F.interpolate(
+            images, scale_factor=scale, mode='bilinear',
+            align_corners=align_corners
+        )
+        images = 255 * 0.5 * (images + 1)  # (-1, 1) -> (0, 255)
+        images = images.flip(1)  # RGB to BGR
+        images -= self.mean[None, :, None, None]
+        _, _, H, W = images.shape
+
+        logits = self.model(images)
+        logits = F.interpolate(
+            logits, size=(H, W), mode="bilinear",
+            align_corners=align_corners
+        )
+        probs = F.softmax(logits, dim=1)
+        pred = torch.argmax(probs, dim=1)
+        return pred
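+
+
+# Note (sketch): inputs are assumed to be RGB images in [-1, 1]; the forward pass
+# rescales them to [0, 255], flips channels to BGR and subtracts the Caffe-style
+# per-channel mean expected by the DeepLabV2 checkpoint, then returns per-pixel
+# argmax class labels at the input resolution.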
diff --git a/imaginaire/evaluation/segmentation/common.py b/imaginaire/evaluation/segmentation/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d90ae26db4747e5be859a8c1237a288e3146b7
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/common.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import boto3
+import torch
+from torch import nn, distributed as dist
+from torch.nn import functional as F
+
+from imaginaire.utils.distributed import is_local_master
+from imaginaire.utils.io import download_file_from_google_drive
+
+
+def get_segmentation_hist_model(dataset_name, aws_credentials=None):
+    if dist.is_initialized() and not is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        torch.distributed.barrier()
+
+    # Load the segmentation network.
+    if dataset_name == "celebamask_hq":
+        from imaginaire.evaluation.segmentation.celebamask_hq import Unet
+        seg_network = Unet()
+        os.makedirs(os.path.join(torch.hub.get_dir(), 'checkpoints'), exist_ok=True)
+        model_path = os.path.join(os.path.join(torch.hub.get_dir(), 'checkpoints'), "celebamask_hq.pt")
+        if not os.path.exists(model_path):
+            if aws_credentials is not None:
+                s3 = boto3.client('s3', **aws_credentials)
+                s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path)
+            else:
+                download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path)
+        state_dict = torch.load(model_path, map_location='cpu')
+        seg_network.load_state_dict(state_dict)
+    elif dataset_name == "cocostuff" or dataset_name == "getty":
+        from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2
+        seg_network = DeepLabV2()
+    else:
+        print(f"No segmentation network for {dataset_name} was found.")
+        return None
+
+    if dist.is_initialized() and is_local_master():
+        # Make sure only the first process in distributed training downloads
+        # the model, and the others will use the cache
+        # noinspection PyUnresolvedReferences
+        torch.distributed.barrier()
+
+    if seg_network is not None:
+        seg_network = seg_network.to('cuda').eval()
+
+    return SegmentationHistModel(seg_network)
+
+
+class SegmentationHistModel(nn.Module):
+    def __init__(self, seg_network):
+        super().__init__()
+        self.seg_network = seg_network
+
+    def forward(self, data, fake_images, align_corners=True):
+        pred = self.seg_network(fake_images, align_corners=align_corners)
+        gt = data["segmaps"]
+        gt = gt * 255.0
+        gt = gt.long()
+        # print(fake_images.shape, fake_images.min(), fake_images.max())
+        # print(gt.shape, gt.min(), gt.max())
+        # exit()
+        return compute_hist(pred, gt, self.seg_network.n_classes, self.seg_network.use_dont_care)
+
+
+def compute_hist(pred, gt, n_classes, use_dont_care):
+    _, H, W = pred.size()
+    gt = F.interpolate(gt.float(), (H, W), mode="nearest").long().squeeze(1)
+    ignore_idx = n_classes if use_dont_care else -1
+    all_hist = []
+    for cur_pred, cur_gt in zip(pred, gt):
+        keep = torch.logical_not(cur_gt == ignore_idx)
+        merge = cur_pred[keep] * n_classes + cur_gt[keep]
+        hist = torch.bincount(merge, minlength=n_classes ** 2)
+        hist = hist.view((n_classes, n_classes))
+        all_hist.append(hist)
+    all_hist = torch.stack(all_hist)
+    return all_hist
+
+
+def get_miou(hist, eps=1e-8):
+    hist = hist.sum(0)
+    IOUs = torch.diag(hist) / (
+            torch.sum(hist, dim=0, keepdim=False) + torch.sum(hist, dim=1, keepdim=False) - torch.diag(hist) + eps)
+    mIOU = 100 * torch.mean(IOUs).item()
+    return {"seg_mIOU": mIOU}
diff --git a/imaginaire/generators/__init__.py b/imaginaire/generators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/generators/__pycache__/__init__.cpython-38.pyc b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c29051fe61ab456c65828f39f1e1130d389be734
Binary files /dev/null and b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fa5fc572212d24f822b7bcdc1d936b848ea4adb
Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc differ
diff --git a/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..205291ec52f52992a82b459886ef1b3455ee39f0
Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc differ
diff --git a/imaginaire/generators/coco_funit.py b/imaginaire/generators/coco_funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f555fb91c9d8580ccf69e1e785c5b6c5a54aef
--- /dev/null
+++ b/imaginaire/generators/coco_funit.py
@@ -0,0 +1,194 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+
+from imaginaire.generators.funit import (MLP, ContentEncoder, Decoder,
+                                         StyleEncoder)
+
+
+class Generator(nn.Module):
+    r"""COCO-FUNIT Generator.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        r"""COCO-FUNIT Generator constructor.
+
+        Args:
+            gen_cfg (obj): Generator definition part of the yaml config file.
+            data_cfg (obj): Data definition part of the yaml config file.
+        """
+        super().__init__()
+        self.generator = COCOFUNITTranslator(**vars(gen_cfg))
+
+    def forward(self, data):
+        r"""In the FUNIT's forward pass, it generates a content embedding and
+        a style code from the content image, and a style code from the style
+        image. By mixing the content code and the style code from the content
+        image, we reconstruct the input image. By mixing the content code and
+        the style code from the style image, we have a translation output.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        content_a = self.generator.content_encoder(data['images_content'])
+        style_a = self.generator.style_encoder(data['images_content'])
+        style_b = self.generator.style_encoder(data['images_style'])
+        images_trans = self.generator.decode(content_a, style_b)
+        images_recon = self.generator.decode(content_a, style_a)
+
+        net_G_output = dict(images_trans=images_trans,
+                            images_recon=images_recon)
+        return net_G_output
+
+    def inference(self, data, keep_original_size=True):
+        r"""COCO-FUNIT inference.
+
+        Args:
+            data (dict): Input data at the current iteration.
+              - images_content (tensor): Content images.
+              - images_style (tensor): Style images.
+            keep_original_size (bool): If ``True``, the output image is resized
+                to the input content image size.
+        """
+        content_a = self.generator.content_encoder(data['images_content'])
+        style_b = self.generator.style_encoder(data['images_style'])
+        output_images = self.generator.decode(content_a, style_b)
+        if keep_original_size:
+            height = data['original_h_w'][0][0]
+            width = data['original_h_w'][0][1]
+            # print('( H, W) = ( %d, %d)' % (height, width))
+            output_images = torch.nn.functional.interpolate(
+                output_images, size=[height, width])
+        file_names = data['key']['images_content'][0]
+        return output_images, file_names
+
+
+class COCOFUNITTranslator(nn.Module):
+    r"""COCO-FUNIT Generator architecture.
+
+    Args:
+        num_filters (int): Base filter numbers.
+        num_filters_mlp (int): Base filter number in the MLP module.
+        style_dims (int): Dimension of the style code.
+        usb_dims (int): Dimension of the universal style bias code.
+        num_res_blocks (int): Number of residual blocks at the end of the
+            content encoder.
+        num_mlp_blocks (int): Number of layers in the MLP module.
+        num_downsamples_content (int): Number of times we reduce
+            resolution by 2x2 for the content image.
+        num_downsamples_style (int): Number of times we reduce
+            resolution by 2x2 for the style image.
+        num_image_channels (int): Number of input image channels.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+    """
+
+    def __init__(self,
+                 num_filters=64,
+                 num_filters_mlp=256,
+                 style_dims=64,
+                 usb_dims=1024,
+                 num_res_blocks=2,
+                 num_mlp_blocks=3,
+                 num_downsamples_style=4,
+                 num_downsamples_content=2,
+                 num_image_channels=3,
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+
+        self.style_encoder = StyleEncoder(num_downsamples_style,
+                                          num_image_channels,
+                                          num_filters,
+                                          style_dims,
+                                          'reflect',
+                                          'none',
+                                          weight_norm_type,
+                                          'relu')
+
+        self.content_encoder = ContentEncoder(num_downsamples_content,
+                                              num_res_blocks,
+                                              num_image_channels,
+                                              num_filters,
+                                              'reflect',
+                                              'instance',
+                                              weight_norm_type,
+                                              'relu')
+
+        self.decoder = Decoder(self.content_encoder.output_dim,
+                               num_filters_mlp,
+                               num_image_channels,
+                               num_downsamples_content,
+                               'reflect',
+                               weight_norm_type,
+                               'relu')
+
+        self.usb = torch.nn.Parameter(torch.randn(1, usb_dims))
+
+        self.mlp = MLP(style_dims,
+                       num_filters_mlp,
+                       num_filters_mlp,
+                       num_mlp_blocks,
+                       'none',
+                       'relu')
+
+        num_content_mlp_blocks = 2
+        num_style_mlp_blocks = 2
+        self.mlp_content = MLP(self.content_encoder.output_dim,
+                               style_dims,
+                               num_filters_mlp,
+                               num_content_mlp_blocks,
+                               'none',
+                               'relu')
+
+        self.mlp_style = MLP(style_dims + usb_dims,
+                             style_dims,
+                             num_filters_mlp,
+                             num_style_mlp_blocks,
+                             'none',
+                             'relu')
+
+    def forward(self, images):
+        r"""Reconstruct the input image by combining the computer content and
+        style code.
+
+        Args:
+            images (tensor): Input image tensor.
+        """
+        # reconstruct an image
+        content, style = self.encode(images)
+        images_recon = self.decode(content, style)
+        return images_recon
+
+    def encode(self, images):
+        r"""Encoder images to get their content and style codes.
+
+        Args:
+            images (tensor): Input image tensor.
+        """
+        style = self.style_encoder(images)
+        content = self.content_encoder(images)
+        return content, style
+
+    def decode(self, content, style):
+        r"""Generate images by combining their content and style codes.
+
+        Args:
+            content (tensor): Content code tensor.
+            style (tensor): Style code tensor.
+        """
+        content_style_code = content.mean(3).mean(2)
+        content_style_code = self.mlp_content(content_style_code)
+        batch_size = style.size(0)
+        usb = self.usb.repeat(batch_size, 1)
+        style = style.view(batch_size, -1)
+        style_in = self.mlp_style(torch.cat([style, usb], 1))
+        coco_style = style_in * content_style_code
+        coco_style = self.mlp(coco_style)
+        images = self.decoder(content, coco_style)
+        return images
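+
+
+# Usage sketch (hedged; in practice the translator is built from the yaml config):
+#   net = COCOFUNITTranslator()
+#   content, style = net.encode(content_images)
+#   recon = net.decode(content, style)        # reconstruction of the content image
+#   trans = net.decode(content, other_style)  # translation using another style code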
diff --git a/imaginaire/generators/craft_2stage.py b/imaginaire/generators/craft_2stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9c8bd42feb3fae32a96625f885a6752e5254d1
--- /dev/null
+++ b/imaginaire/generators/craft_2stage.py
@@ -0,0 +1,65 @@
+import torch
+from torch import nn
+import functools
+import torch.nn.functional as F
+
+import sys 
+sys.path.append(".") 
+from model import geometry_transform
+from imaginaire.utils.distributed import master_only_print as print
+from model.graphs.decoder import DeepLab
+from imaginaire.generators.craft_base import *
+
+
+
+class Generator(nn.Module):
+    def __init__(self,opt):
+        super(Generator, self).__init__()
+        gen_cfg = opt.arch.gen
+        data_cfg = opt.data
+        # self.gen_model = gen_model
+        self.gen_cfg = opt.arch.gen
+        if gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']:
+            self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device)
+        if gen_cfg.transform_mode == 'volum_rendering':
+            last_act = 'relu'
+        else:
+            last_act = 'softmax'
+        self.depth_model = inner_Generator(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act)
+        render_input_channel = 3
+        if gen_cfg.cat_opa:
+            render_input_channel = render_input_channel+1
+        self.denoise_model = inner_Generator(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid')
+
+        self.PE = None
+
+
+
+    def forward(self, inputs, style_img=None,opt=None):
+        estimated_height = self.depth_model(inputs)
+
+        if self.gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']:
+            geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE)
+            generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+            # 'voxel' is optional in the render output; default to None so the
+            # output key below is always defined.
+            voxel = geo_outputs.get('voxel', None)
+        # mu, logvar, z = self.style_encode(style_img)
+        # z = self.style_model(z)
+        if self.gen_cfg.cat_opa:
+            generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+        output_RGB = self.denoise_model(generator_inputs)
+        out_put = {
+            'pred': output_RGB,
+            # 'inter_RGB': generator_inputs,  ### out_feature not for show
+            # 'mu' :mu,
+            # 'logvar' : logvar,
+            }
+        if self.gen_cfg.transform_mode in ['volum_rendering']:
+            out_put['opacity'] = opacity
+        if self.gen_cfg.transform_mode:
+            out_put['estimated_height'] = estimated_height
+        out_put['generator_inputs'] = generator_inputs
+        out_put['voxel'] = voxel
+        out_put['depth'] = depth
+        return out_put
+
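+# Pipeline sketch (hedged): `depth_model` predicts a per-pixel height/density map
+# from the satellite view, `geometry_transform.render` projects it along the
+# panoramic ray directions in `pano_direction` to get an initial ground-view RGB
+# plus opacity/depth, and `denoise_model` refines that projection (optionally
+# concatenated with opacity) into the final prediction.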
diff --git a/imaginaire/generators/craft_2stage_add_style.py b/imaginaire/generators/craft_2stage_add_style.py
new file mode 100644
index 0000000000000000000000000000000000000000..7920abdf110783ea2921e5f8faa98c3a7c284ec9
--- /dev/null
+++ b/imaginaire/generators/craft_2stage_add_style.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+import sys 
+sys.path.append(".") 
+from model import geometry_transform
+from imaginaire.generators.craft_base import *
+
+
+
+class Generator(nn.Module):
+    def __init__(self,opt):
+        super(Generator, self).__init__()
+        gen_cfg = opt.arch.gen
+        data_cfg = opt.data
+        style_enc_cfg = opt.arch.gen.style_enc_cfg
+        # self.gen_model = gen_model
+        self.style_inject = getattr(gen_cfg, 'style_inject',
+                                       None)
+        self.gen_cfg = opt.arch.gen
+        self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device)
+        last_act = 'relu'
+        self.depth_model = inner_Generator_split(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act)
+
+
+        render_input_channel = 3
+        if gen_cfg.cat_opa:
+            render_input_channel +=1
+        if gen_cfg.cat_depth:
+            render_input_channel +=1
+
+        self.denoise_model = inner_Generator_split(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid')
+        if self.style_inject:
+            if self.style_inject=='histo':
+                self.style_encode = histo_process(style_enc_cfg)
+            elif self.style_inject=='perspective':
+                self.style_encode = StyleEncoder(style_enc_cfg)
+            else:
+                raise Exception('Unknown style inject')
+            self.style_model = StyleMLP(style_dim=style_enc_cfg.style_dims, out_dim=style_enc_cfg.interm_style_dims, hidden_channels=style_enc_cfg.hidden_channel, leaky_relu=True, num_layers=5, normalize_input=True,
+                        output_act=True)
+
+        self.PE = geometry_transform.position_produce(opt) if gen_cfg.cat_PE else None
+
+
+
+    def forward(self, inputs, style_img=None,opt=None):
+        # predicted height of satellite images
+        estimated_height = self.depth_model(inputs)
+        geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE)
+        generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+        # 'voxel' is optional in the render output; default to None so the
+        # output key below is always defined.
+        voxel = geo_outputs.get('voxel', None)
+
+        if self.gen_cfg.cat_opa:
+            generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+        if self.gen_cfg.cat_depth:
+            generator_inputs = torch.cat((generator_inputs,depth),dim=1)
+        if self.style_inject:
+            mu, logvar, z = self.style_encode(style_img)
+            z = self.style_model(z)
+        else:
+            z = None
+        # merge multiple sources(rgb,opacity,depth and sky) and denoise redundancy
+        output_RGB = self.denoise_model(generator_inputs,z)
+        out_put = {'pred': output_RGB}
+        if self.style_inject:
+            out_put['mu'] = mu
+            out_put['logvar']  = logvar
+        out_put['estimated_height'] = estimated_height
+        out_put['generator_inputs'] = generator_inputs
+        out_put['voxel'] = voxel
+        out_put['depth'] = depth
+        out_put['opacity'] = opacity
+        return out_put
+
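+# Pipeline sketch (hedged): same two-stage layout as craft_2stage, with the
+# refinement stage additionally conditioned on a style code. `style_encode` maps
+# either a colour histogram ('histo') or a perspective ground image
+# ('perspective') to (mu, logvar, z), and `style_model` (a StyleMLP) turns z into
+# the vector used by inner_Generator_split for its FiLM-style modulation.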
diff --git a/imaginaire/generators/craft_base.py b/imaginaire/generators/craft_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f283c701269ad8eaa656036ade597d37393ebc
--- /dev/null
+++ b/imaginaire/generators/craft_base.py
@@ -0,0 +1,483 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Upsample as NearestUpsample
+import torch.nn.functional as F
+from functools import partial
+
+import sys 
+sys.path.append(".") 
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+
+
+class StyleMLP(nn.Module):
+    r"""MLP converting style code to intermediate style representation."""
+
+    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
+                 output_act=True):
+        super(StyleMLP, self).__init__()
+
+        self.normalize_input = normalize_input
+        self.output_act = output_act
+        fc_layers = []
+        fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
+        for i in range(num_layers-1):
+            fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
+        self.fc_layers = nn.ModuleList(fc_layers)
+
+        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
+
+        if leaky_relu:
+            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            self.act = partial(F.relu, inplace=True)
+
+    def forward(self, z):
+        r""" Forward network
+
+        Args:
+            z (N x style_dim tensor): Style codes.
+        """
+        if self.normalize_input:
+            z = F.normalize(z, p=2, dim=-1,eps=1e-6)
+        for fc_layer in self.fc_layers:
+            z = self.act(fc_layer(z))
+        z = self.fc_out(z)
+        if self.output_act:
+            z = self.act(z)
+        return z
+
+class histo_process(nn.Module):
+    r"""Histo process to replace Style Encoder constructor.
+
+    Args:
+        style_enc_cfg (obj): Style encoder definition file.
+    """
+    def __init__(self,style_enc_cfg):
+        super().__init__()
+        # if style_enc_cfg.histo.mode in ['RGB','rgb']:
+        input_channel=270
+        # else:
+            # input_channel=90
+        style_dims = style_enc_cfg.style_dims
+        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+        num_filters = getattr(style_enc_cfg, 'num_filters', 180)
+        self.process_model = nn.ModuleList()
+        self.layer1 = LinearBlock(input_channel,num_filters)
+        self.layer4 = LinearBlock(num_filters, num_filters)
+        self.fc_mu = LinearBlock(num_filters, style_dims,nonlinearity='tanh')
+        if not self.no_vae:
+            self.fc_var = LinearBlock(num_filters, style_dims,nonlinearity='tanh')
+
+
+    def forward(self,histo):
+        x = self.layer1(histo)
+        x = self.layer4(x)
+        mu = self.fc_mu(x) #[-1,1]
+        if not self.no_vae:
+            logvar = self.fc_var(x) # [-1,1]
+            std = torch.exp(0.5 * logvar)  # [0.607,1.624]
+            eps = torch.randn_like(std) 
+            z = eps.mul(std) + mu
+        else:
+            z = mu
+            logvar = torch.zeros_like(mu)
+        return mu, logvar, z
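+# Note (sketch): the sampling above is the standard VAE reparameterisation,
+# z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I), so z stays stochastic
+# while gradients flow through mu and logvar.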
+
+
+
+class StyleEncoder(nn.Module):
+    r"""Style Encoder constructor.
+
+    Args:
+        style_enc_cfg (obj): Style encoder definition file.
+    """
+
+    def __init__(self, style_enc_cfg):
+        super(StyleEncoder, self).__init__()
+        input_image_channels = style_enc_cfg.input_image_channels
+        num_filters = style_enc_cfg.num_filters
+        kernel_size = style_enc_cfg.kernel_size
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        style_dims = style_enc_cfg.style_dims
+        weight_norm_type = style_enc_cfg.weight_norm_type
+        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+        activation_norm_type = 'none'
+        nonlinearity = 'leakyrelu'
+        base_conv2d_block = \
+            partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=2,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              # inplace_nonlinearity=True,
+                              nonlinearity=nonlinearity)
+        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
+        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh')
+        if not self.no_vae:
+            self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh')
+
+    def forward(self, input_x):
+        r"""SPADE Style Encoder forward.
+
+        Args:
+            input_x (N x 3 x H x W tensor): input images.
+        Returns:
+            mu (N x C tensor): Mean vectors.
+            logvar (N x C tensor): Log-variance vectors.
+            z (N x C tensor): Style code vectors.
+        """
+        if input_x.size(2) != 256 or input_x.size(3) != 256:
+            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+        x = self.layer1(input_x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = x.view(x.size(0), -1)
+        mu = self.fc_mu(x)
+        if not self.no_vae:
+            logvar = self.fc_var(x)
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            z = eps.mul(std) + mu
+        else:
+            z = mu
+            logvar = torch.zeros_like(mu)
+        return mu, logvar, z
+
+
+class RenderCNN(nn.Module):
+    r"""CNN converting intermediate feature map to final image."""
+
+    def __init__(self, in_channels, style_dim, hidden_channels=256,
+                 leaky_relu=True):
+        super(RenderCNN, self).__init__()
+        self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
+
+        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
+        self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+        self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+        self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+        self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+        self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+        self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+
+        self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
+
+        if leaky_relu:
+            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            self.act = partial(F.relu, inplace=True)
+
+    def modulate(self, x, w_, b_):
+        w_ = w_[..., None, None]
+        b_ = b_[..., None, None]
+        return x * (w_+1) + b_ +1e-9
+
+    def forward(self, x, z):
+        r"""Forward network.
+
+        Args:
+            x (N x in_channels x H x W tensor): Intermediate feature map
+            z (N x style_dim tensor): Style codes.
+        """
+        z = self.fc_z_cond(z)
+        adapt = torch.chunk(z, 2 * 2, dim=-1)
+        y = self.act(self.conv1(x))
+
+        y = y + self.conv2b(self.act(self.conv2a(y)))
+        y = self.act(self.modulate(y, adapt[0], adapt[1]))
+
+        y = y + self.conv3b(self.act(self.conv3a(y)))
+        y = self.act(self.modulate(y, adapt[2], adapt[3]))
+
+        y = y + self.conv4b(self.act(self.conv4a(y)))
+        y = self.act(y)
+
+        y = self.conv4(y)
+        y = torch.sigmoid(y)
+        return y
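+# Note (sketch): `modulate` is a FiLM-style affine transform, y * (w + 1) + b,
+# where the (w, b) pairs come from the style code via `fc_z_cond`, split into
+# 2 * 2 chunks for the two modulation sites above.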
+
+
+class inner_Generator(nn.Module):
+    r"""Pix2pixHD coarse-to-fine generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+        last_act:  ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``; default is ``'relu'``.
+    """
+
+    def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'):
+        super().__init__()
+        assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
+            'tanh' , 'sigmoid' , 'softmax']
+        # pix2pixHD has a global generator.
+        global_gen_cfg = inner_cfg
+        # By default, pix2pixHD using instance normalization.
+        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+                                       'instance')
+        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+                                         None)
+        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+        base_conv_block = partial(Conv2dBlock,
+                                  padding_mode=padding_mode,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  activation_norm_params=activation_norm_params,
+                                  nonlinearity='relu')
+        base_res_block = partial(Res2dBlock,
+                                 padding_mode=padding_mode,
+                                 weight_norm_type=weight_norm_type,
+                                 activation_norm_type=activation_norm_type,
+                                 activation_norm_params=activation_norm_params,
+                                 nonlinearity='relu', order='CNACN')
+        # Know what is the number of available segmentation labels.
+
+        # Global generator model.
+        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
+                                       num_input_channels, padding_mode,
+                                       base_conv_block, base_res_block,last_act=last_act)
+        self.global_model = global_model
+
+
+    def forward(self, input, random_style=False):
+        r"""Coarse-to-fine generator forward.
+
+        Args:
+            input (4D tensor): Input tensor passed to the global generator.
+            random_style (bool): Always set to False for the pix2pixHD model.
+        Returns:
+            output (4D tensor): Output of the global generator.
+        """
+        return self.global_model(input)
+
+
+
+    def load_pretrained_network(self, pretrained_dict):
+        r"""Load a pretrained network."""
+        # print(pretrained_dict.keys())
+        model_dict = self.state_dict()
+        print('Pretrained network has fewer layers; the following are '
+              'not initialized:')
+
+        not_initialized = set()
+        for k, v in model_dict.items():
+            kp = 'module.' + k.replace('global_model.', 'global_model.model.')
+            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+                model_dict[k] = pretrained_dict[kp]
+            else:
+                not_initialized.add('.'.join(k.split('.')[:2]))
+        print(sorted(not_initialized))
+        self.load_state_dict(model_dict)
+
+    def inference(self, data, **kwargs):
+        r"""Generator inference.
+
+        Args:
+            data (dict) : Dictionary of input data.
+        Returns:
+            fake_images (tensor): Output fake images.
+            file_names (str): Data file name.
+        """
+        output = self.forward(data, **kwargs)
+        return output['fake_images'], data['key']['seg_maps'][0]
+
+
+class GlobalGenerator(nn.Module):
+    r"""Coarse generator constructor. This is the main generator in the
+    pix2pixHD architecture.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+        num_input_channels (int): Number of segmentation labels.
+        padding_mode (str): zero | reflect | ...
+        base_conv_block (obj): Conv block with preset attributes.
+        base_res_block (obj): Residual block with preset attributes.
+        last_act (str, optional, default='relu'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+    """
+
+    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
+                 base_conv_block, base_res_block,last_act='relu'):
+        super(GlobalGenerator, self).__init__()
+
+        # num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        num_out_put_channels = getattr(gen_cfg, 'output_nc', 64)
+        num_filters = getattr(gen_cfg, 'num_filters', 64)
+        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
+        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
+        # First layer.
+        model = [base_conv_block(num_input_channels, num_filters,
+                                 kernel_size=7, padding=3)]
+        # Downsample.
+        for i in range(num_downsamples):
+            ch = num_filters * (2 ** i)
+            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+        # ResNet blocks.
+        ch = num_filters * (2 ** num_downsamples)
+        for i in range(num_res_blocks):
+            model += [base_res_block(ch, ch, 3, padding=1)]
+        # Upsample.
+        num_upsamples = num_downsamples
+        for i in reversed(range(num_upsamples)):
+            ch = num_filters * (2 ** i)
+            model += \
+                [NearestUpsample(scale_factor=2),
+                 base_conv_block(ch * 2, ch, 3, padding=1)]
+        model += [Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
+                              padding_mode=padding_mode, nonlinearity=last_act)]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        r"""Coarse-to-fine generator forward.
+
+        Args:
+            input (4D tensor) : Input semantic representations.
+        Returns:
+            output (4D tensor) : Synthesized image by generator.
+        """
+        return self.model(input)
+
+class inner_Generator_split(nn.Module):
+    r"""Pix2pixHD coarse-to-fine generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+        last_act:  ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``; default is ``'relu'``.
+    """
+
+    def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'):
+        super().__init__()
+        assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
+            'tanh' , 'sigmoid' , 'softmax']
+        # pix2pixHD has a global generator.
+        # By default, pix2pixHD using instance normalization.
+        print(inner_cfg)
+        style_dim =  gen_cfg.style_enc_cfg.interm_style_dims
+        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+                                       'instance')
+        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+                                         None)
+        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+        # num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        # num_input_channels = 3
+        base_conv_block = partial(Conv2dBlock,
+                                  padding_mode=padding_mode,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  activation_norm_params=activation_norm_params,
+                                )
+        base_res_block = partial(Res2dBlock,
+                                 padding_mode=padding_mode,
+                                 weight_norm_type=weight_norm_type,
+                                 activation_norm_type=activation_norm_type,
+                                 activation_norm_params=activation_norm_params,
+                                 nonlinearity='relu', order='CNACN')
+        # Know what is the number of available segmentation labels.
+
+        # Global generator model.
+
+        num_out_put_channels = getattr(inner_cfg, 'output_nc', 64)
+        num_filters = getattr(inner_cfg, 'num_filters', 64)
+        num_downsamples = 4
+        num_res_blocks = getattr(inner_cfg, 'num_res_blocks', 9)
+        # First layer.
+        model = [base_conv_block(num_input_channels, num_filters,
+                                 kernel_size=7, padding=3)]
+        model += [nn.PReLU()]
+        # Downsample.
+        for i in range(num_downsamples):
+            ch = num_filters * (2 ** i)
+            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+            model += [nn.PReLU()]
+        # ResNet blocks.
+        ch = num_filters * (2 ** num_downsamples)
+        for i in range(num_res_blocks):
+            model += [base_res_block(ch, ch, 3, padding=1)]
+        self.model = nn.Sequential(*model)
+        # Upsample.
+        assert num_downsamples == 4
+        if not (inner_cfg.name =='render' and gen_cfg.style_inject):
+            list = [16,8,4,2]
+        else:
+            list = [16,6,6,6]
+
+        self.up0_a = NearestUpsample(scale_factor=2)
+        self.up0_b = base_conv_block(num_filters * list[0], num_filters*list[1], 3, padding=1)
+        self.up1_a = NearestUpsample(scale_factor=2)
+        self.up1_b = base_conv_block(num_filters * list[1], num_filters*list[2], 3, padding=1)
+        self.up2_a = NearestUpsample(scale_factor=2)
+        self.up2_b = base_conv_block(num_filters * list[2], num_filters*list[3], 3, padding=1)
+        self.up3_a = NearestUpsample(scale_factor=2)
+        self.up3_b = base_conv_block(num_filters * list[3], num_filters, 3, padding=1)
+        self.up_end = Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
+                              padding_mode=padding_mode, nonlinearity=last_act)
+        if inner_cfg.name =='render' and gen_cfg.style_inject:
+            self.fc_z_cond = nn.Linear(style_dim, 4* list[-1] * num_filters)
+
+    def modulate(self, x, w, b):
+        w = w[..., None, None]
+        b = b[..., None, None]
+        return x * (w+1) + b
+
+    def forward(self, input, z=None):
+        r"""Split coarse-to-fine generator forward.
+
+        Args:
+            input (4D tensor): Input tensor.
+            z (tensor, optional): Intermediate style code used to modulate the
+                middle upsampling stages; no modulation is applied if ``None``.
+        Returns:
+            output (4D tensor): Synthesized image by generator.
+        """
+        if z is not None:
+            z = self.fc_z_cond(z)
+            adapt = torch.chunk(z, 2 * 2, dim=-1)
+        input = self.model(input)
+        input = self.up0_a(input)
+        input = self.up0_b(input)
+        input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+        input = self.up1_a(input)
+        input = self.up1_b(input)
+        if z is not None:
+            input = self.modulate(input, adapt[0], adapt[1])
+        input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+        input = self.up2_a(input)
+        input = self.up2_b(input)
+        if z is not None:
+            input = self.modulate(input, adapt[2], adapt[3])
+        input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+        input = self.up3_a(input)
+        input = self.up3_b(input)
+        input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+        input = self.up_end(input)
+
+        return input
+
+if __name__ == '__main__':
+    from easydict import EasyDict as edict
+    # Minimal smoke test: build only the config attributes histo_process reads.
+    style_enc_cfg = edict()
+    style_enc_cfg.histo = edict()
+    style_enc_cfg.histo.mode = 'RGB'
+    style_enc_cfg.num_filters = 180
+    style_enc_cfg.style_dims = 128  # hypothetical value, just for this smoke test
+    model = histo_process(style_enc_cfg)
\ No newline at end of file
diff --git a/imaginaire/generators/dummy.py b/imaginaire/generators/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9a2f1edec286be8751d6a188ebc4f47875c437
--- /dev/null
+++ b/imaginaire/generators/dummy.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Generator(nn.Module):
+    r"""Dummy generator.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super(Generator, self).__init__()
+        self.dummy_layer = LinearBlock(1, 1)
+        pass
+
+    def forward(self, data):
+        r"""Dummy Generator forward.
+
+        Args:
+            data (dict):
+        """
+        return
diff --git a/imaginaire/generators/fs_vid2vid.py b/imaginaire/generators/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c2c20048d0e371d9393302cc858fe738a4d53b
--- /dev/null
+++ b/imaginaire/generators/fs_vid2vid.py
@@ -0,0 +1,1176 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import (Conv2dBlock, HyperConv2dBlock, HyperRes2dBlock,
+                               LinearBlock, Res2dBlock)
+from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
+                                               pick_image, resample)
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.init_weight import weights_init
+from imaginaire.utils.misc import get_and_setattr, get_nested_attr
+
+
+class Generator(nn.Module):
+    r"""Few-shot vid2vid generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.gen_cfg = gen_cfg
+        self.data_cfg = data_cfg
+        self.num_frames_G = data_cfg.num_frames_G
+        self.flow_cfg = flow_cfg = gen_cfg.flow
+
+        # For pose dataset.
+        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
+        if self.is_pose_data:
+            pose_cfg = data_cfg.for_pose_dataset
+            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
+            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
+                                              False)
+
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        self.num_downsamples = num_downsamples = \
+            get_and_setattr(gen_cfg, 'num_downsamples', 5)
+        conv_kernel_size = get_and_setattr(gen_cfg, 'kernel_size', 3)
+        num_filters = get_and_setattr(gen_cfg, 'num_filters', 32)
+
+        max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
+        self.max_num_filters = gen_cfg.max_num_filters = \
+            min(max_num_filters, num_filters * (2 ** num_downsamples))
+        # Get number of filters at each layer in the main branch.
+        num_filters_each_layer = [min(self.max_num_filters,
+                                      num_filters * (2 ** i))
+                                  for i in range(num_downsamples + 2)]
+
+        # Hyper normalization / convolution.
+        hyper_cfg = gen_cfg.hyper
+        # Use adaptive weight generation for SPADE.
+        self.use_hyper_spade = hyper_cfg.is_hyper_spade
+        # Use adaptive for convolutional layers in the main branch.
+        self.use_hyper_conv = hyper_cfg.is_hyper_conv
+        # Number of hyper layers.
+        self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
+        if self.num_hyper_layers == -1:
+            self.num_hyper_layers = num_downsamples
+        gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers
+        # Network weight generator.
+        self.weight_generator = WeightGenerator(gen_cfg, data_cfg)
+
+        # Number of layers to perform multi-spade combine.
+        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
+                                              'num_layers', 3)
+        # Whether to generate raw output for additional losses.
+        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
+                                           False)
+
+        # Main branch image generation.
+        padding = conv_kernel_size // 2
+        activation_norm_type = get_and_setattr(gen_cfg, 'activation_norm_type',
+                                               'sync_batch')
+        weight_norm_type = get_and_setattr(gen_cfg, 'weight_norm_type',
+                                           'spectral')
+        activation_norm_params = get_and_setattr(gen_cfg,
+                                                 'activation_norm_params',
+                                                 None)
+        spade_in_channels = []  # Input channel size in SPADE module.
+        for i in range(num_downsamples + 1):
+            spade_in_channels += [[num_filters_each_layer[i]]] \
+                if i >= self.num_multi_spade_layers \
+                else [[num_filters_each_layer[i]] * 3]
+
+        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
+        for i in reversed(range(num_downsamples + 1)):
+            activation_norm_params.cond_dims = spade_in_channels[i]
+            is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
+            is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers
+            setattr(self, 'up_%d' % i, HyperRes2dBlock(
+                num_filters_each_layer[i + 1], num_filters_each_layer[i],
+                conv_kernel_size, padding=padding,
+                weight_norm_type=weight_norm_type,
+                activation_norm_type=activation_norm_type,
+                activation_norm_params=activation_norm_params,
+                order=order * 2,
+                is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm))
+
+        self.conv_img = Conv2dBlock(num_filters, num_img_channels,
+                                    conv_kernel_size, padding=padding,
+                                    nonlinearity='leakyrelu', order='AC')
+        self.upsample = partial(F.interpolate, scale_factor=2)
+
+        # Flow estimation module.
+        # Whether to warp reference image and combine with the synthesized.
+        self.warp_ref = getattr(flow_cfg, 'warp_ref', True)
+        if self.warp_ref:
+            self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
+            self.ref_image_embedding = \
+                LabelEmbedder(flow_cfg.multi_spade_combine.embed,
+                              num_img_channels + 1)
+        # At beginning of training, only train an image generator.
+        self.temporal_initialized = False
+        if getattr(gen_cfg, 'init_temporal', True):
+            self.init_temporal_network()
+
+    def forward(self, data):
+        r"""few-shot vid2vid generator forward.
+
+        Args:
+            data (dict) : Dictionary of input data.
+        Returns:
+            output (dict) : Dictionary of output data.
+        """
+        label = data['label']
+        ref_labels, ref_images = data['ref_labels'], data['ref_images']
+        prev_labels, prev_images = data['prev_labels'], data['prev_images']
+        is_first_frame = prev_labels is None
+
+        if self.is_pose_data:
+            label, prev_labels = extract_valid_pose_labels(
+                [label, prev_labels], self.pose_type, self.remove_face_labels)
+            ref_labels = extract_valid_pose_labels(
+                ref_labels, self.pose_type, self.remove_face_labels,
+                do_remove=False)
+
+        # Weight generation.
+        x, encoded_label, conv_weights, norm_weights, atn, atn_vis, ref_idx = \
+            self.weight_generator(ref_images, ref_labels, label, is_first_frame)
+
+        # Flow estimation.
+        flow, flow_mask, img_warp, cond_inputs = \
+            self.flow_generation(label, ref_labels, ref_images,
+                                 prev_labels, prev_images, ref_idx)
+
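+        # Wrap each encoded label in a list so that SPADE_combine can append
+        # the warped-image embeddings as additional conditional inputs.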
+        for i in range(len(encoded_label)):
+            encoded_label[i] = [encoded_label[i]]
+        if self.generate_raw_output:
+            encoded_label_raw = [encoded_label[i] for i in
+                                 range(self.num_multi_spade_layers)]
+            x_raw = None
+        encoded_label = self.SPADE_combine(encoded_label, cond_inputs)
+
+        # Main branch image generation.
+        for i in range(self.num_downsamples, -1, -1):
+            conv_weight = norm_weight = [None] * 3
+            if self.use_hyper_conv and i < self.num_hyper_layers:
+                conv_weight = conv_weights[i]
+            if self.use_hyper_spade and i < self.num_hyper_layers:
+                norm_weight = norm_weights[i]
+
+            # Main branch residual blocks.
+            x = self.one_up_conv_layer(x, encoded_label,
+                                       conv_weight, norm_weight, i)
+
+            # For raw output generation.
+            if self.generate_raw_output and i < self.num_multi_spade_layers:
+                x_raw = self.one_up_conv_layer(x_raw, encoded_label_raw,
+                                               conv_weight, norm_weight, i)
+            else:
+                x_raw = x
+
+        # Final conv layer.
+        if self.generate_raw_output:
+            img_raw = torch.tanh(self.conv_img(x_raw))
+        else:
+            img_raw = None
+        img_final = torch.tanh(self.conv_img(x))
+
+        output = dict()
+        output['fake_images'] = img_final
+        output['fake_flow_maps'] = flow
+        output['fake_occlusion_masks'] = flow_mask
+        output['fake_raw_images'] = img_raw
+        output['warped_images'] = img_warp
+        output['attention_visualization'] = atn_vis
+        output['ref_idx'] = ref_idx
+        return output
+
+    def one_up_conv_layer(self, x, encoded_label, conv_weight, norm_weight, i):
+        r"""One residual block layer in the main branch.
+
+        Args:
+            x (4D tensor) : Current feature map.
+            encoded_label (list of tensors) : Encoded input label maps.
+            conv_weight (list of tensors) : Hyper conv weights.
+            norm_weight (list of tensors) : Hyper norm weights.
+            i (int) : Layer index.
+        Returns:
+            x (4D tensor) : Output feature map.
+        """
+        layer = getattr(self, 'up_' + str(i))
+        x = layer(x, *encoded_label[i], conv_weights=conv_weight,
+                  norm_weights=norm_weight)
+        if i != 0:
+            x = self.upsample(x)
+        return x
+
+    def init_temporal_network(self, cfg_init=None):
+        r"""When starting training multiple frames, initialize the flow network.
+
+        Args:
+            cfg_init (dict) : Weight initialization config.
+        """
+        flow_cfg = self.flow_cfg
+        emb_cfg = flow_cfg.multi_spade_combine.embed
+        num_frames_G = self.num_frames_G
+        self.temporal_initialized = True
+
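+        # A separate flow network for previous frames is needed when it is
+        # explicitly requested, when the generator conditions on more than two
+        # frames, or when reference warping is disabled (so there is no
+        # reference flow network to share).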
+        self.sep_prev_flownet = flow_cfg.sep_prev_flow or (num_frames_G != 2) \
+            or not flow_cfg.warp_ref
+        if self.sep_prev_flownet:
+            self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg,
+                                                   num_frames_G)
+            if cfg_init is not None:
+                self.flow_network_temp.apply(weights_init(cfg_init.type,
+                                                          cfg_init.gain))
+        else:
+            self.flow_network_temp = self.flow_network_ref
+
+        self.sep_prev_embedding = emb_cfg.sep_warp_embed or \
+            not flow_cfg.warp_ref
+        if self.sep_prev_embedding:
+            num_img_channels = get_paired_input_image_channel_number(
+                self.data_cfg)
+            self.prev_image_embedding = \
+                LabelEmbedder(emb_cfg, num_img_channels + 1)
+            if cfg_init is not None:
+                self.prev_image_embedding.apply(
+                    weights_init(cfg_init.type, cfg_init.gain))
+        else:
+            self.prev_image_embedding = self.ref_image_embedding
+
+        if self.warp_ref:
+            if self.sep_prev_flownet:
+                self.init_network_weights(self.flow_network_ref,
+                                          self.flow_network_temp)
+                print('Initialized temporal flow network with the reference '
+                      'one.')
+            if self.sep_prev_embedding:
+                self.init_network_weights(self.ref_image_embedding,
+                                          self.prev_image_embedding)
+                print('Initialized temporal embedding network with the '
+                      'reference one.')
+            self.flow_temp_is_initalized = True
+
+    def init_network_weights(self, net_src, net_dst):
+        r"""Initialize weights in net_dst with those in net_src."""
+        source_weights = net_src.state_dict()
+        target_weights = net_dst.state_dict()
+
+        for k, v in source_weights.items():
+            if k in target_weights and target_weights[k].size() == v.size():
+                target_weights[k] = v
+        net_dst.load_state_dict(target_weights)
+
+    def load_pretrained_network(self, pretrained_dict, prefix='module.'):
+        r"""Load the pretrained network into self network.
+
+        Args:
+            pretrained_dict (dict): Pretrained network weights.
+            prefix (str): Prefix to the network weights name.
+        """
+        # print(pretrained_dict.keys())
+        model_dict = self.state_dict()
+        print('Pretrained network has fewer layers; the following are '
+              'not initialized:')
+
+        not_initialized = set()
+        for k, v in model_dict.items():
+            kp = prefix + k
+            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+                model_dict[k] = pretrained_dict[kp]
+            else:
+                not_initialized.add('.'.join(k.split('.')[:2]))
+        print(sorted(not_initialized))
+        self.load_state_dict(model_dict)
+
+    def reset(self):
+        r"""Reset the network at the beginning of a sequence."""
+        self.weight_generator.reset()
+
+    def flow_generation(self, label, ref_labels, ref_images, prev_labels,
+                        prev_images, ref_idx):
+        r"""Generates flows and masks for warping reference / previous images.
+
+        Args:
+            label (NxCxHxW tensor): Target label map.
+            ref_labels (NxKxCxHxW tensor): Reference label maps.
+            ref_images (NxKx3xHxW tensor): Reference images.
+            prev_labels (NxTxCxHxW tensor): Previous label maps.
+            prev_images (NxTx3xHxW tensor): Previous images.
+            ref_idx (Nx1 tensor): Index for which image to use from the
+            reference images.
+        Returns:
+            (tuple):
+              - flow (list of Nx2xHxW tensor): Optical flows.
+              - occ_mask (list of Nx1xHxW tensor): Occlusion masks.
+              - img_warp (list of Nx3xHxW tensor): Warped reference / previous
+                images.
+              - cond_inputs (list of Nx4xHxW tensor): Conditional inputs for
+                SPADE combination.
+        """
+        # Pick an image in the reference images using ref_idx.
+        ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
+        # Only start using prev frames when enough prev frames are generated.
+        has_prev = prev_labels is not None and \
+            prev_labels.shape[1] == (self.num_frames_G - 1)
+        flow, occ_mask, img_warp, cond_inputs = [None] * 2, [None] * 2, \
+                                                [None] * 2, [None] * 2
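+        # Index 0 of each list holds the reference-image warp, index 1 the
+        # previous-frame warp; entries stay None if a branch is inactive.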
+        if self.warp_ref:
+            # Generate flows/masks for warping the reference image.
+            flow_ref, occ_mask_ref = \
+                self.flow_network_ref(label, ref_label, ref_image)
+            ref_image_warp = resample(ref_image, flow_ref)
+            flow[0], occ_mask[0], img_warp[0] = \
+                flow_ref, occ_mask_ref, ref_image_warp[:, :3]
+            # Concat warped image and occlusion mask to form the conditional
+            # input.
+            cond_inputs[0] = torch.cat([img_warp[0], occ_mask[0]], dim=1)
+
+        if self.temporal_initialized and has_prev:
+            # Generate flows/masks for warping the previous image.
+            b, t, c, h, w = prev_labels.shape
+            prev_labels_concat = prev_labels.view(b, -1, h, w)
+            prev_images_concat = prev_images.view(b, -1, h, w)
+            flow_prev, occ_mask_prev = \
+                self.flow_network_temp(label, prev_labels_concat,
+                                       prev_images_concat)
+            img_prev_warp = resample(prev_images[:, -1], flow_prev)
+            flow[1], occ_mask[1], img_warp[1] = \
+                flow_prev, occ_mask_prev, img_prev_warp
+            cond_inputs[1] = torch.cat([img_warp[1], occ_mask[1]], dim=1)
+
+        return flow, occ_mask, img_warp, cond_inputs
+
+    def SPADE_combine(self, encoded_label, cond_inputs):
+        r"""Using Multi-SPADE to combine raw synthesized image with warped
+        images.
+
+        Args:
+            encoded_label (list of tensors): Original label map embeddings.
+            cond_inputs (list of tensors): New SPADE conditional inputs from the
+                warped images.
+        Returns:
+            encoded_label (list of tensors): Combined conditional inputs.
+        """
+        # Generate the conditional embeddings from inputs.
+        embedded_img_feat = [None, None]
+        if cond_inputs[0] is not None:
+            embedded_img_feat[0] = self.ref_image_embedding(cond_inputs[0])
+        if cond_inputs[1] is not None:
+            embedded_img_feat[1] = self.prev_image_embedding(cond_inputs[1])
+
+        # Combine the original encoded label maps with new conditional
+        # embeddings.
+        for i in range(self.num_multi_spade_layers):
+            encoded_label[i] += [w[i] if w is not None else None
+                                 for w in embedded_img_feat]
+        return encoded_label
+
+    def custom_init(self):
+        r"""This function is for dealing with the numerical issue that might
+        occur when doing mixed precision training.
+        """
+        print('Use custom initialization for the generator.')
+        for k, m in self.named_modules():
+            if 'weight_generator.ref_label_' in k and 'norm' in k:
+                m.eps = 1e-1
+
+
+class WeightGenerator(nn.Module):
+    r"""Weight generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.data_cfg = data_cfg
+        self.embed_cfg = embed_cfg = gen_cfg.embed
+        self.embed_arch = embed_cfg.arch
+
+        num_filters = gen_cfg.num_filters
+        self.max_num_filters = gen_cfg.max_num_filters
+        self.num_downsamples = num_downsamples = gen_cfg.num_downsamples
+        self.num_filters_each_layer = num_filters_each_layer = \
+            [min(self.max_num_filters, num_filters * (2 ** i))
+             for i in range(num_downsamples + 2)]
+        if getattr(embed_cfg, 'num_filters', 32) != num_filters:
+            raise ValueError('Embedding network must have the same number of '
+                             'filters as generator.')
+
+        # Normalization params.
+        hyper_cfg = gen_cfg.hyper
+        kernel_size = getattr(hyper_cfg, 'kernel_size', 3)
+        activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
+                                       'sync_batch')
+        weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')
+        # Conv kernel size in main branch.
+        self.conv_kernel_size = conv_kernel_size = gen_cfg.kernel_size
+        # Conv kernel size in embedding network.
+        self.embed_kernel_size = embed_kernel_size = \
+            getattr(gen_cfg.embed, 'kernel_size', 3)
+        # Conv kernel size in SPADE.
+        self.kernel_size = kernel_size = \
+            getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
+        # Input channel size in SPADE module.
+        self.spade_in_channels = []
+        for i in range(num_downsamples + 1):
+            self.spade_in_channels += [num_filters_each_layer[i]]
+
+        # Hyper normalization / convolution.
+        # Use adaptive weight generation for SPADE.
+        self.use_hyper_spade = hyper_cfg.is_hyper_spade
+        # Use adaptive weight generation for the label embedding network.
+        self.use_hyper_embed = hyper_cfg.is_hyper_embed
+        # Use adaptive weight generation for convolutional layers in the
+        # main branch.
+        self.use_hyper_conv = hyper_cfg.is_hyper_conv
+        # Number of hyper layers.
+        self.num_hyper_layers = hyper_cfg.num_hyper_layers
+        # Order of operations in the conv block.
+        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
+        self.conv_before_norm = order.find('C') < order.find('N')
+
+        # For reference image encoding.
+        # How to utilize the reference label map: concat | mul.
+        self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
+        self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels
+        # Output spatial size for adaptive pooling layer.
+        self.sh_fix = self.sw_fix = 32
+        # Number of fc layers in weight generation.
+        self.num_fc_layers = getattr(hyper_cfg, 'num_fc_layers', 2)
+
+        # Reference image encoding network.
+        num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        if num_input_channels == 0:
+            num_input_channels = getattr(data_cfg, 'label_channels', 1)
+        elif get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
+                             'both') == 'open':
+            num_input_channels -= 3
+        data_cfg.num_input_channels = num_input_channels
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        num_ref_channels = num_img_channels + (num_input_channels
+                                               if self.concat_ref_label else 0)
+        conv_2d_block = partial(
+            Conv2dBlock, kernel_size=kernel_size,
+            padding=(kernel_size // 2), weight_norm_type=weight_norm_type,
+            activation_norm_type=activation_norm_type,
+            nonlinearity='leakyrelu')
+
+        self.ref_img_first = conv_2d_block(num_ref_channels, num_filters)
+        if self.mul_ref_label:
+            self.ref_label_first = conv_2d_block(num_input_channels,
+                                                 num_filters)
+
+        for i in range(num_downsamples):
+            in_ch, out_ch = num_filters_each_layer[i], \
+                num_filters_each_layer[i + 1]
+            setattr(self, 'ref_img_down_%d' % i,
+                    conv_2d_block(in_ch, out_ch, stride=2))
+            setattr(self, 'ref_img_up_%d' % i, conv_2d_block(out_ch, in_ch))
+            if self.mul_ref_label:
+                setattr(self, 'ref_label_down_%d' % i,
+                        conv_2d_block(in_ch, out_ch, stride=2))
+                setattr(self, 'ref_label_up_%d' % i,
+                        conv_2d_block(out_ch, in_ch))
+
+        # Normalization / main branch conv weight generation.
+        if self.use_hyper_spade or self.use_hyper_conv:
+            for i in range(self.num_hyper_layers):
+                ch_in, ch_out = num_filters_each_layer[i], \
+                    num_filters_each_layer[i + 1]
+                conv_ks2 = conv_kernel_size ** 2
+                embed_ks2 = embed_kernel_size ** 2
+                spade_ks2 = kernel_size ** 2
+                spade_in_ch = self.spade_in_channels[i]
+
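+                # The `+ 1` terms below reserve space for the bias elements
+                # that WeightReshaper.reshape_weight later splits off from the
+                # generated weight vectors.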
+                fc_names, fc_ins, fc_outs = [], [], []
+                if self.use_hyper_spade:
+                    fc0_out = fcs_out = (spade_in_ch * spade_ks2 + 1) * (
+                        1 if self.conv_before_norm else 2)
+                    fc1_out = (spade_in_ch * spade_ks2 + 1) * (
+                        1 if ch_in != ch_out else 2)
+                    fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
+                    fc_ins += [ch_out] * 3
+                    fc_outs += [fc0_out, fc1_out, fcs_out]
+                    if self.use_hyper_embed:
+                        fc_names += ['fc_spade_e']
+                        fc_ins += [ch_out]
+                        fc_outs += [ch_in * embed_ks2 + 1]
+                if self.use_hyper_conv:
+                    fc0_out = ch_out * conv_ks2 + 1
+                    fc1_out = ch_in * conv_ks2 + 1
+                    fcs_out = ch_out + 1
+                    fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
+                    fc_ins += [ch_in] * 3
+                    fc_outs += [fc0_out, fc1_out, fcs_out]
+
+                linear_block = partial(LinearBlock,
+                                       weight_norm_type='spectral',
+                                       nonlinearity='leakyrelu')
+                for n, l in enumerate(fc_names):
+                    fc_in = fc_ins[n] if self.mul_ref_label \
+                        else self.sh_fix * self.sw_fix
+                    fc_layer = [linear_block(fc_in, ch_out)]
+                    for k in range(1, self.num_fc_layers):
+                        fc_layer += [linear_block(ch_out, ch_out)]
+                    fc_layer += [LinearBlock(ch_out, fc_outs[n],
+                                             weight_norm_type='spectral')]
+                    setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))
+
+        # Label embedding network.
+        num_hyper_layers = self.num_hyper_layers if self.use_hyper_embed else 0
+        self.label_embedding = LabelEmbedder(self.embed_cfg,
+                                             num_input_channels,
+                                             num_hyper_layers=num_hyper_layers)
+
+        # For multiple reference images.
+        if hasattr(hyper_cfg, 'attention'):
+            self.num_downsample_atn = get_and_setattr(hyper_cfg.attention,
+                                                      'num_downsamples', 2)
+            if data_cfg.initial_few_shot_K > 1:
+                self.attention_module = AttentionModule(hyper_cfg, data_cfg,
+                                                        conv_2d_block,
+                                                        num_filters_each_layer)
+        else:
+            self.num_downsample_atn = 0
+
+    def forward(self, ref_image, ref_label, label, is_first_frame):
+        r"""Generate network weights based on the reference images.
+
+        Args:
+            ref_image (NxKx3xHxW tensor): Reference images.
+            ref_label (NxKxCxHxW tensor): Reference labels.
+            label (NxCxHxW tensor): Target label.
+            is_first_frame (bool): Whether the current frame is the first frame.
+
+        Returns:
+            (tuple):
+              - x (NxC2xH2xW2 tensor): Encoded features from reference images
+                for the main branch (as input to the decoder).
+              - encoded_label (list of tensors): Encoded target label map for
+                SPADE.
+              - conv_weights (list of tensors): Network weights for conv
+                layers in the main network.
+              - norm_weights (list of tensors): Network weights for SPADE
+                layers in the main network.
+              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
+                scores.
+              - ref_idx (Nx1 tensor): Index for which image to use from the
+                reference images.
+        """
+        b, k, c, h, w = ref_image.size()
+        ref_image = ref_image.view(b * k, -1, h, w)
+        if ref_label is not None:
+            ref_label = ref_label.view(b * k, -1, h, w)
+
+        # Encode the reference images to get the features.
+        x, encoded_ref, atn, atn_vis, ref_idx = \
+            self.encode_reference(ref_image, ref_label, label, k)
+
+        # If the reference image has changed, recompute the network weights.
+        if self.training or is_first_frame or k > 1:
+            embedding_weights, norm_weights, conv_weights = [], [], []
+            for i in range(self.num_hyper_layers):
+                if self.use_hyper_spade:
+                    feat = encoded_ref[min(len(encoded_ref) - 1, i + 1)]
+                    embedding_weight, norm_weight = \
+                        self.get_norm_weights(feat, i)
+                    embedding_weights.append(embedding_weight)
+                    norm_weights.append(norm_weight)
+                if self.use_hyper_conv:
+                    feat = encoded_ref[min(len(encoded_ref) - 1, i)]
+                    conv_weights.append(self.get_conv_weights(feat, i))
+
+            if not self.training:
+                self.embedding_weights, self.conv_weights, self.norm_weights \
+                    = embedding_weights, conv_weights, norm_weights
+        else:
+            # print('Reusing network weights.')
+            embedding_weights, conv_weights, norm_weights \
+                = self.embedding_weights, self.conv_weights, self.norm_weights
+
+        # Encode the target label to get the encoded features.
+        encoded_label = self.label_embedding(label, weights=(
+            embedding_weights if self.use_hyper_embed else None))
+
+        return x, encoded_label, conv_weights, norm_weights, \
+            atn, atn_vis, ref_idx
+
+    def encode_reference(self, ref_image, ref_label, label, k):
+        r"""Encode the reference image to get features for weight generation.
+
+        Args:
+            ref_image ((NxK)x3xHxW tensor): Reference images.
+            ref_label ((NxK)xCxHxW tensor): Reference labels.
+            label (NxCxHxW tensor): Target label.
+            k (int): Number of reference images.
+        Returns:
+            (tuple):
+              - x (NxC2xH2xW2 tensor): Encoded features from reference images
+                for the main branch (as input to the decoder).
+              - encoded_ref (list of tensors): Encoded features from reference
+                images for the weight generation branch.
+              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+              - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores.
+              - ref_idx (Nx1 tensor): Index for which image to use from the
+                reference images.
+        """
+        if self.concat_ref_label:
+            # Concat reference label map and image together for encoding.
+            concat_ref = torch.cat([ref_image, ref_label], dim=1)
+            x = self.ref_img_first(concat_ref)
+        elif self.mul_ref_label:
+            # Apply conv to both reference label and image, then multiply them
+            # together for encoding.
+            x = self.ref_img_first(ref_image)
+            x_label = self.ref_label_first(ref_label)
+        else:
+            x = self.ref_img_first(ref_image)
+
+        # Attention map and the index of the most similar reference image.
+        atn = atn_vis = ref_idx = None
+        for i in range(self.num_downsamples):
+            x = getattr(self, 'ref_img_down_' + str(i))(x)
+            if self.mul_ref_label:
+                x_label = getattr(self, 'ref_label_down_' + str(i))(x_label)
+
+            # Combine different reference images at a particular layer.
+            if k > 1 and i == self.num_downsample_atn - 1:
+                x, atn, atn_vis = self.attention_module(x, label, ref_label)
+                if self.mul_ref_label:
+                    x_label, _, _ = self.attention_module(x_label, None, None,
+                                                          atn)
+
+                atn_sum = atn.view(label.shape[0], k, -1).sum(2)
+                ref_idx = torch.argmax(atn_sum, dim=1)
+
+        # Get all corresponding layers in the encoder output for generating
+        # weights in corresponding layers.
+        encoded_image_ref = [x]
+        if self.mul_ref_label:
+            encoded_ref_label = [x_label]
+
+        for i in reversed(range(self.num_downsamples)):
+            conv = getattr(self, 'ref_img_up_' + str(i))(
+                encoded_image_ref[-1])
+            encoded_image_ref.append(conv)
+            if self.mul_ref_label:
+                conv_label = getattr(self, 'ref_label_up_' + str(i))(
+                    encoded_ref_label[-1])
+                encoded_ref_label.append(conv_label)
+
+        if self.mul_ref_label:
+            encoded_ref = []
+            for i in range(len(encoded_image_ref)):
+                conv, conv_label \
+                    = encoded_image_ref[i], encoded_ref_label[i]
+                b, c, h, w = conv.size()
+                conv_label = nn.Softmax(dim=1)(conv_label)
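+                # Correlate every image-feature channel with every
+                # (softmax-normalized) label-feature channel over all spatial
+                # positions, giving a (B x C x C x 1) descriptor.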
+                conv_prod = (conv.view(b, c, 1, h * w) *
+                             conv_label.view(b, 1, c,
+                                             h * w)).sum(3, keepdim=True)
+                encoded_ref.append(conv_prod)
+        else:
+            encoded_ref = encoded_image_ref
+        encoded_ref = encoded_ref[::-1]
+
+        return x, encoded_ref, atn, atn_vis, ref_idx
+
+    def get_norm_weights(self, x, i):
+        r"""Adaptively generate weights for SPADE in layer i of generator.
+
+        Args:
+            x (NxCxHxW tensor): Input features.
+            i (int): Layer index.
+        Returns:
+            (tuple):
+              - embedding_weights (list of tensors): Weights for the label
+                embedding network.
+              - norm_weights (list of tensors): Weights for the SPADE layers.
+        """
+        if not self.mul_ref_label:
+            # Get fixed output size for fc layers.
+            x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)
+
+        in_ch = self.num_filters_each_layer[i]
+        out_ch = self.num_filters_each_layer[i + 1]
+        spade_ch = self.spade_in_channels[i]
+        eks, sks = self.embed_kernel_size, self.kernel_size
+
+        b = x.size(0)
+        weight_reshaper = WeightReshaper()
+        x = weight_reshaper.reshape_embed_input(x)
+
+        # Weights for the label embedding network.
+        embedding_weights = None
+        if self.use_hyper_embed:
+            fc_e = getattr(self, 'fc_spade_e_' + str(i))(x).view(b, -1)
+            if 'decoder' in self.embed_arch:
+                weight_shape = [in_ch, out_ch, eks, eks]
+                fc_e = fc_e[:, :-in_ch]
+            else:
+                weight_shape = [out_ch, in_ch, eks, eks]
+            embedding_weights = weight_reshaper.reshape_weight(fc_e,
+                                                               weight_shape)
+
+        # Weights for the 3 layers in SPADE module: conv_0, conv_1,
+        # and shortcut.
+        fc_0 = getattr(self, 'fc_spade_0_' + str(i))(x).view(b, -1)
+        fc_1 = getattr(self, 'fc_spade_1_' + str(i))(x).view(b, -1)
+        fc_s = getattr(self, 'fc_spade_s_' + str(i))(x).view(b, -1)
+        if self.conv_before_norm:
+            out_ch = in_ch
+        weight_0 = weight_reshaper.reshape_weight(fc_0, [out_ch * 2, spade_ch,
+                                                         sks, sks])
+        weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch * 2, spade_ch,
+                                                         sks, sks])
+        weight_s = weight_reshaper.reshape_weight(fc_s, [out_ch * 2, spade_ch,
+                                                         sks, sks])
+        norm_weights = [weight_0, weight_1, weight_s]
+
+        return embedding_weights, norm_weights
+
+    def get_conv_weights(self, x, i):
+        r"""Adaptively generate weights for layer i in main branch convolutions.
+
+        Args:
+            x (NxCxHxW tensor): Input features.
+            i (int): Layer index.
+        Returns:
+            (tuple):
+              - conv_weights (list of tensors): Weights for the conv layers in
+                the main branch.
+        """
+        if not self.mul_ref_label:
+            x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)
+        in_ch = self.num_filters_each_layer[i]
+        out_ch = self.num_filters_each_layer[i + 1]
+        cks = self.conv_kernel_size
+        b = x.size()[0]
+        weight_reshaper = WeightReshaper()
+        x = weight_reshaper.reshape_embed_input(x)
+
+        fc_0 = getattr(self, 'fc_conv_0_' + str(i))(x).view(b, -1)
+        fc_1 = getattr(self, 'fc_conv_1_' + str(i))(x).view(b, -1)
+        fc_s = getattr(self, 'fc_conv_s_' + str(i))(x).view(b, -1)
+        weight_0 = weight_reshaper.reshape_weight(fc_0, [in_ch, out_ch,
+                                                         cks, cks])
+        weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch, in_ch,
+                                                         cks, cks])
+        weight_s = weight_reshaper.reshape_weight(fc_s, [in_ch, out_ch, 1, 1])
+        return [weight_0, weight_1, weight_s]
+
+    def reset(self):
+        r"""Reset the network at the beginning of a sequence."""
+        self.embedding_weights = self.conv_weights = self.norm_weights = None
+
+
+class WeightReshaper():
+    r"""Handles all weight reshape related tasks."""
+    def reshape_weight(self, x, weight_shape):
+        r"""Reshape input x to the desired weight shape.
+
+        Args:
+            x (tensor or list of tensors): Input features.
+            weight_shape (list of int): Desired shape of the weight.
+        Returns:
+            (tuple):
+              - weight (tensor): Network weights
+              - bias (tensor): Network bias.
+        """
+        # If desired shape is a list, first divide x into the target list of
+        # features.
+        if type(weight_shape[0]) == list and type(x) != list:
+            x = self.split_weights(x, self.sum_mul(weight_shape))
+
+        if type(x) == list:
+            return [self.reshape_weight(xi, wi)
+                    for xi, wi in zip(x, weight_shape)]
+
+        # Get output shape, and divide x into either weight + bias or
+        # just weight.
+        weight_shape = [x.size(0)] + weight_shape
+        bias_size = weight_shape[1]
+        try:
+            weight = x[:, :-bias_size].view(weight_shape)
+            bias = x[:, -bias_size:]
+        except Exception:
+            weight = x.view(weight_shape)
+            bias = None
+        return [weight, bias]
+
+    def split_weights(self, weight, sizes):
+        r"""When the desired shape is a list, first divide the input to each
+        corresponding weight shape in the list.
+
+        Args:
+            weight (tensor): Input weight.
+            sizes (int or list of int): Target sizes.
+        Returns:
+            weight (list of tensors): Divided weights.
+        """
+        if isinstance(sizes, list):
+            weights = []
+            cur_size = 0
+            for i in range(len(sizes)):
+                # For each target size in sizes, get the number of elements
+                # needed.
+                next_size = cur_size + self.sum(sizes[i])
+                # Recursively divide the weights.
+                weights.append(self.split_weights(
+                    weight[:, cur_size:next_size], sizes[i]))
+                cur_size = next_size
+            assert (next_size == weight.size(1))
+            return weights
+        return weight
+
+    def reshape_embed_input(self, x):
+        r"""Reshape input to be (B x C) X H X W.
+
+        Args:
+            x (tensor or list of tensors): Input features.
+        Returns:
+            x (tensor or list of tensors): Reshaped features.
+        """
+        if isinstance(x, list):
+            return [self.reshape_embed_input(xi) for xi in x]
+        b, c, _, _ = x.size()
+        x = x.view(b * c, -1)
+        return x
+
+    def sum(self, x):
+        r"""Sum all elements recursively in a nested list.
+
+        Args:
+            x (nested list of int): Input list of elements.
+        Returns:
+            out (int): Sum of all elements.
+        """
+        if type(x) != list:
+            return x
+        return sum([self.sum(xi) for xi in x])
+
+    def sum_mul(self, x):
+        r"""Given a weight shape, compute the number of elements needed for
+        weight + bias. If input is a list of shapes, sum all the elements.
+
+        Args:
+            x (list of int): Input list of elements.
+        Returns:
+            out (int or list of int): Summed number of elements.
+        """
+        assert (type(x) == list)
+        if type(x[0]) != list:
+            return np.prod(x) + x[0]  # x[0] accounts for bias.
+        return [self.sum_mul(xi) for xi in x]
+
+
+class AttentionModule(nn.Module):
+    r"""Attention module constructor.
+
+    Args:
+       atn_cfg (obj): Generator definition part of the yaml config file.
+       data_cfg (obj): Data definition part of the yaml config file.
+       conv_2d_block: Conv2DBlock constructor.
+       num_filters_each_layer (int): The number of filters in each layer.
+    """
+
+    def __init__(self, atn_cfg, data_cfg, conv_2d_block,
+                 num_filters_each_layer):
+        super().__init__()
+        self.initial_few_shot_K = data_cfg.initial_few_shot_K
+        num_input_channels = data_cfg.num_input_channels
+        num_filters = getattr(atn_cfg, 'num_filters', 32)
+
+        self.num_downsample_atn = getattr(atn_cfg, 'num_downsamples', 2)
+        self.atn_query_first = conv_2d_block(num_input_channels, num_filters)
+        self.atn_key_first = conv_2d_block(num_input_channels, num_filters)
+        for i in range(self.num_downsample_atn):
+            f_in, f_out = num_filters_each_layer[i], \
+                num_filters_each_layer[i + 1]
+            setattr(self, 'atn_key_%d' % i,
+                    conv_2d_block(f_in, f_out, stride=2))
+            setattr(self, 'atn_query_%d' % i,
+                    conv_2d_block(f_in, f_out, stride=2))
+
+    def forward(self, in_features, label, ref_label, attention=None):
+        r"""Get the attention map to combine multiple image features in the
+        case of multiple reference images.
+
+        Args:
+            in_features ((NxK)xC1xH1xW1 tensor): Input features.
+            label (NxC2xH2xW2 tensor): Target label.
+            ref_label (NxC2xH2xW2 tensor): Reference label.
+            attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+        Returns:
+            (tuple):
+              - out_features (NxC1xH1xW1 tensor): Attention-combined features.
+              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+              - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores.
+        """
+        b, c, h, w = in_features.size()
+        k = self.initial_few_shot_K
+        b = b // k
+
+        if attention is None:
+            # Compute the attention map by encoding ref_label and label as
+            # key and query. The map represents how much energy for the k-th
+            # map at location (h_i, w_j) can contribute to the final map at
+            # location (h_i2, w_j2).
+            atn_key = self.attention_encode(ref_label, 'atn_key')
+            atn_query = self.attention_encode(label, 'atn_query')
+
+            atn_key = atn_key.view(b, k, c, -1).permute(
+                0, 1, 3, 2).contiguous().view(b, -1, c)  # B X KHW X C
+            atn_query = atn_query.view(b, c, -1)  # B X C X HW
+            energy = torch.bmm(atn_key, atn_query)  # B X KHW X HW
+            attention = nn.Softmax(dim=1)(energy)
+
+        # Combine the K features from different ref images into one by using
+        # the attention map.
+        in_features = in_features.view(b, k, c, h * w).permute(
+            0, 2, 1, 3).contiguous().view(b, c, -1)  # B X C X KHW
+        out_features = torch.bmm(in_features, attention).view(b, c, h, w)
+
+        # Get a slice of the attention map for visualization.
+        atn_vis = attention.view(b, k, h * w, h * w).sum(2).view(b, k, h, w)
+        return out_features, attention, atn_vis[-1:, 0:1]
+
+    def attention_encode(self, img, net_name):
+        r"""Encode the input image to get the attention map.
+
+        Args:
+            img (NxCxHxW tensor): Input image.
+            net_name (str): Name for attention network.
+        Returns:
+            x (NxC2xH2xW2 tensor): Encoded feature.
+        """
+        x = getattr(self, net_name + '_first')(img)
+        for i in range(self.num_downsample_atn):
+            x = getattr(self, net_name + '_' + str(i))(x)
+        return x
+
+
+class FlowGenerator(nn.Module):
+    r"""flow generator constructor.
+
+    Args:
+       flow_cfg (obj): Flow definition part of the yaml config file.
+       data_cfg (obj): Data definition part of the yaml config file.
+       num_frames (int): Number of input frames.
+    """
+
+    def __init__(self, flow_cfg, data_cfg, num_frames):
+        super().__init__()
+        num_input_channels = data_cfg.num_input_channels
+        if num_input_channels == 0:
+            num_input_channels = 1
+        num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
+        num_downsamples = getattr(flow_cfg, 'num_downsamples', 3)
+        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
+        padding = kernel_size // 2
+        num_blocks = getattr(flow_cfg, 'num_blocks', 6)
+        num_filters = getattr(flow_cfg, 'num_filters', 32)
+        max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
+        num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i))
+                                  for i in range(num_downsamples + 1)]
+
+        self.flow_output_multiplier = getattr(flow_cfg,
+                                              'flow_output_multiplier', 20)
+        self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False)
+        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
+                                       'sync_batch')
+        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
+
+        base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
+                                  padding=padding,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  nonlinearity='leakyrelu')
+
+        num_input_channels = num_input_channels * num_frames + \
+            num_prev_img_channels * (num_frames - 1)
+        # First layer.
+        down_flow = [base_conv_block(num_input_channels, num_filters)]
+
+        # Downsamples.
+        for i in range(num_downsamples):
+            down_flow += [base_conv_block(num_filters_each_layer[i],
+                                          num_filters_each_layer[i + 1],
+                                          stride=2)]
+
+        # Resnet blocks.
+        res_flow = []
+        ch = num_filters_each_layer[num_downsamples]
+        for i in range(num_blocks):
+            res_flow += [
+                Res2dBlock(ch, ch, kernel_size, padding=padding,
+                           weight_norm_type=weight_norm_type,
+                           activation_norm_type=activation_norm_type,
+                           order='NACNAC')]
+
+        # Upsamples.
+        up_flow = []
+        for i in reversed(range(num_downsamples)):
+            up_flow += [nn.Upsample(scale_factor=2),
+                        base_conv_block(num_filters_each_layer[i + 1],
+                                        num_filters_each_layer[i])]
+
+        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
+        conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
+                                 nonlinearity='sigmoid')]
+
+        self.down_flow = nn.Sequential(*down_flow)
+        self.res_flow = nn.Sequential(*res_flow)
+        self.up_flow = nn.Sequential(*up_flow)
+        if self.sep_up_mask:
+            self.up_mask = nn.Sequential(*copy.deepcopy(up_flow))
+        self.conv_flow = nn.Sequential(*conv_flow)
+        self.conv_mask = nn.Sequential(*conv_mask)
+
+    def forward(self, label, ref_label, ref_image):
+        r"""Flow generator forward.
+
+        Args:
+            label (4D tensor) : Input label tensor.
+            ref_label (4D tensor) : Reference label tensors.
+            ref_image (4D tensor) : Reference image tensors.
+        Returns:
+            (tuple):
+              - flow (4D tensor) : Generated flow map.
+              - mask (4D tensor) : Generated occlusion mask.
+        """
+        label_concat = torch.cat([label, ref_label, ref_image], dim=1)
+        downsample = self.down_flow(label_concat)
+        res = self.res_flow(downsample)
+        flow_feat = self.up_flow(res)
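+        # Scale the raw prediction by flow_output_multiplier to reach the
+        # expected flow magnitude range.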
+        flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
+
+        mask_feat = self.up_mask(res) if self.sep_up_mask else flow_feat
+        mask = self.conv_mask(mask_feat)
+        return flow, mask
+
+
+class LabelEmbedder(nn.Module):
+    r"""Embed the input label map to get embedded features.
+
+    Args:
+        emb_cfg (obj): Embed network configuration.
+        num_input_channels (int): Number of input channels.
+        num_hyper_layers (int): Number of hyper layers.
+    """
+
+    def __init__(self, emb_cfg, num_input_channels, num_hyper_layers=0):
+        super().__init__()
+        num_filters = getattr(emb_cfg, 'num_filters', 32)
+        max_num_filters = getattr(emb_cfg, 'max_num_filters', 1024)
+        self.arch = getattr(emb_cfg, 'arch', 'encoderdecoder')
+        self.num_downsamples = num_downsamples = \
+            getattr(emb_cfg, 'num_downsamples', 5)
+        kernel_size = getattr(emb_cfg, 'kernel_size', 3)
+        weight_norm_type = getattr(emb_cfg, 'weight_norm_type', 'spectral')
+        activation_norm_type = getattr(emb_cfg, 'activation_norm_type', 'none')
+
+        self.unet = 'unet' in self.arch
+        self.has_decoder = 'decoder' in self.arch or self.unet
+        self.num_hyper_layers = num_hyper_layers \
+            if num_hyper_layers != -1 else num_downsamples
+
+        base_conv_block = partial(HyperConv2dBlock, kernel_size=kernel_size,
+                                  padding=(kernel_size // 2),
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  nonlinearity='leakyrelu')
+
+        ch = [min(max_num_filters, num_filters * (2 ** i))
+              for i in range(num_downsamples + 1)]
+
+        self.conv_first = base_conv_block(num_input_channels, num_filters,
+                                          activation_norm_type='none')
+
+        # Downsample.
+        for i in range(num_downsamples):
+            is_hyper_conv = (i < num_hyper_layers) and not self.has_decoder
+            setattr(self, 'down_%d' % i,
+                    base_conv_block(ch[i], ch[i + 1], stride=2,
+                                    is_hyper_conv=is_hyper_conv))
+
+        # Upsample.
+        if self.has_decoder:
+            self.upsample = nn.Upsample(scale_factor=2)
+            for i in reversed(range(num_downsamples)):
+                ch_i = ch[i + 1] * (
+                    2 if self.unet and i != num_downsamples - 1 else 1)
+                setattr(self, 'up_%d' % i,
+                        base_conv_block(ch_i, ch[i],
+                                        is_hyper_conv=(i < num_hyper_layers)))
+
+    def forward(self, input, weights=None):
+        r"""Embedding network forward.
+
+        Args:
+            input (NxCxHxW tensor): Network input.
+            weights (list of tensors): Conv weights if using hyper network.
+        Returns:
+            output (list of tensors): Network outputs at different layers.
+        """
+        if input is None:
+            return None
+        output = [self.conv_first(input)]
+
+        for i in range(self.num_downsamples):
+            layer = getattr(self, 'down_%d' % i)
+            # For hyper networks, the hyper layers are at the last few layers
+            # of decoder (if the network has a decoder). Otherwise, the hyper
+            # layers will be at the first few layers of the network.
+            if i >= self.num_hyper_layers or self.has_decoder:
+                conv = layer(output[-1])
+            else:
+                conv = layer(output[-1], conv_weights=weights[i])
+            # We will use outputs from different layers as input to different
+            # SPADE layers in the main branch.
+            output.append(conv)
+
+        if not self.has_decoder:
+            return output
+
+        # If the network has a decoder, will use outputs from the decoder
+        # layers instead of the encoding layers.
+        if not self.unet:
+            output = [output[-1]]
+
+        for i in reversed(range(self.num_downsamples)):
+            input_i = output[-1]
+            if self.unet and i != self.num_downsamples - 1:
+                input_i = torch.cat([input_i, output[i + 1]], dim=1)
+
+            input_i = self.upsample(input_i)
+            layer = getattr(self, 'up_%d' % i)
+            # The last few layers will be hyper layers if necessary.
+            if i >= self.num_hyper_layers:
+                conv = layer(input_i)
+            else:
+                conv = layer(input_i, conv_weights=weights[i])
+            output.append(conv)
+
+        if self.unet:
+            output = output[self.num_downsamples:]
+        return output[::-1]
diff --git a/imaginaire/generators/funit.py b/imaginaire/generators/funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6520166a4f906afe3c5cd2fda09a6fbc11502213
--- /dev/null
+++ b/imaginaire/generators/funit.py
@@ -0,0 +1,400 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+from imaginaire.layers import \
+    (Conv2dBlock, LinearBlock, Res2dBlock, UpRes2dBlock)
+
+
+class Generator(nn.Module):
+    r"""Generator of the improved FUNIT baseline in the COCO-FUNIT paper.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.generator = FUNITTranslator(**vars(gen_cfg))
+
+    def forward(self, data):
+        r"""In the FUNIT's forward pass, it generates a content embedding and
+        a style code from the content image, and a style code from the style
+        image. By mixing the content code and the style code from the content
+        image, we reconstruct the input image. By mixing the content code and
+        the style code from the style image, we have a translation output.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        content_a = self.generator.content_encoder(data['images_content'])
+        style_a = self.generator.style_encoder(data['images_content'])
+        style_b = self.generator.style_encoder(data['images_style'])
+        images_trans = self.generator.decode(content_a, style_b)
+        images_recon = self.generator.decode(content_a, style_a)
+
+        net_G_output = dict(images_trans=images_trans,
+                            images_recon=images_recon)
+        return net_G_output
+
+    def inference(self, data, keep_original_size=True):
+        r"""COCO-FUNIT inference.
+
+        Args:
+            data (dict): Input data at the current iteration.
+              - images_content (tensor): Content images.
+              - images_style (tensor): Style images.
+            keep_original_size (bool): If ``True``, the output image is
+                resized to the input content image size.
+        """
+        content_a = self.generator.content_encoder(data['images_content'])
+        style_b = self.generator.style_encoder(data['images_style'])
+        output_images = self.generator.decode(content_a, style_b)
+        if keep_original_size:
+            height = data['original_h_w'][0][0]
+            width = data['original_h_w'][0][1]
+            # print('( H, W) = ( %d, %d)' % (height, width))
+            output_images = torch.nn.functional.interpolate(
+                output_images, size=[height, width])
+        file_names = data['key']['images_content'][0]
+        return output_images, file_names
+
+
+class FUNITTranslator(nn.Module):
+    r"""
+
+    Args:
+         num_filters (int): Base filter numbers.
+         num_filters_mlp (int): Base filter number in the MLP module.
+         style_dims (int): Dimension of the style code.
+         num_res_blocks (int): Number of residual blocks at the end of the
+            content encoder.
+         num_mlp_blocks (int): Number of layers in the MLP module.
+         num_downsamples_content (int): Number of times we reduce
+            resolution by 2x2 for the content image.
+         num_downsamples_style (int): Number of times we reduce
+            resolution by 2x2 for the style image.
+         num_image_channels (int): Number of input image channels.
+         weight_norm_type (str): Type of weight normalization.
+             ``'none'``, ``'spectral'``, or ``'weight'``.
+    """
+
+    def __init__(self,
+                 num_filters=64,
+                 num_filters_mlp=256,
+                 style_dims=64,
+                 num_res_blocks=2,
+                 num_mlp_blocks=3,
+                 num_downsamples_style=4,
+                 num_downsamples_content=2,
+                 num_image_channels=3,
+                 weight_norm_type='',
+                 **kwargs):
+        super().__init__()
+
+        self.style_encoder = StyleEncoder(num_downsamples_style,
+                                          num_image_channels,
+                                          num_filters,
+                                          style_dims,
+                                          'reflect',
+                                          'none',
+                                          weight_norm_type,
+                                          'relu')
+
+        self.content_encoder = ContentEncoder(num_downsamples_content,
+                                              num_res_blocks,
+                                              num_image_channels,
+                                              num_filters,
+                                              'reflect',
+                                              'instance',
+                                              weight_norm_type,
+                                              'relu')
+
+        self.decoder = Decoder(self.content_encoder.output_dim,
+                               num_filters_mlp,
+                               num_image_channels,
+                               num_downsamples_content,
+                               'reflect',
+                               weight_norm_type,
+                               'relu')
+
+        self.mlp = MLP(style_dims,
+                       num_filters_mlp,
+                       num_filters_mlp,
+                       num_mlp_blocks,
+                       'none',
+                       'relu')
+
+    def forward(self, images):
+        r"""Reconstruct the input image by combining the computer content and
+        style code.
+
+        Args:
+            images (tensor): Input image tensor.
+        """
+        # reconstruct an image
+        content, style = self.encode(images)
+        images_recon = self.decode(content, style)
+        return images_recon
+
+    def encode(self, images):
+        r"""Encoder images to get their content and style codes.
+
+        Args:
+            images (tensor): Input image tensor.
+        """
+        style = self.style_encoder(images)
+        content = self.content_encoder(images)
+        return content, style
+
+    def decode(self, content, style):
+        r"""Generate images by combining their content and style codes.
+
+        Args:
+            content (tensor): Content code tensor.
+            style (tensor): Style code tensor.
+        """
+        style = self.mlp(style)
+        images = self.decoder(content, style)
+        return images
+
+
+class Decoder(nn.Module):
+    r"""Improved FUNIT decoder.
+
+    Args:
+        num_enc_output_channels (int): Number of content feature channels.
+        style_channels (int): Dimension of the style code.
+        num_image_channels (int): Number of image channels.
+        num_upsamples (int): How many times we are going to apply
+            upsample residual block.
+    """
+
+    def __init__(self,
+                 num_enc_output_channels,
+                 style_channels,
+                 num_image_channels=3,
+                 num_upsamples=4,
+                 padding_type='reflect',
+                 weight_norm_type='none',
+                 nonlinearity='relu'):
+        super().__init__()
+        adain_params = SimpleNamespace(
+            activation_norm_type='instance',
+            activation_norm_params=SimpleNamespace(affine=False),
+            cond_dims=style_channels)
+
+        base_res_block = partial(Res2dBlock,
+                                 kernel_size=3,
+                                 padding=1,
+                                 padding_mode=padding_type,
+                                 nonlinearity=nonlinearity,
+                                 activation_norm_type='adaptive',
+                                 activation_norm_params=adain_params,
+                                 weight_norm_type=weight_norm_type,
+                                 learn_shortcut=False)
+
+        base_up_res_block = partial(UpRes2dBlock,
+                                    kernel_size=5,
+                                    padding=2,
+                                    padding_mode=padding_type,
+                                    weight_norm_type=weight_norm_type,
+                                    activation_norm_type='adaptive',
+                                    activation_norm_params=adain_params,
+                                    skip_activation_norm='instance',
+                                    skip_nonlinearity=nonlinearity,
+                                    nonlinearity=nonlinearity,
+                                    hidden_channels_equal_out_channels=True,
+                                    learn_shortcut=True)
+
+        dims = num_enc_output_channels
+
+        # Residual blocks with AdaIN.
+        self.decoder = nn.ModuleList()
+        self.decoder += [base_res_block(dims, dims)]
+        self.decoder += [base_res_block(dims, dims)]
+        for _ in range(num_upsamples):
+            self.decoder += [base_up_res_block(dims, dims // 2)]
+            dims = dims // 2
+        self.decoder += [Conv2dBlock(dims,
+                                     num_image_channels,
+                                     kernel_size=7,
+                                     stride=1,
+                                     padding=3,
+                                     padding_mode='reflect',
+                                     nonlinearity='tanh')]
+
+    def forward(self, x, style):
+        r"""
+
+        Args:
+            x (tensor): Content embedding of the content image.
+            style (tensor): Style embedding of the style image.
+        """
+        for block in self.decoder:
+            if getattr(block, 'conditional', False):
+                x = block(x, style)
+            else:
+                x = block(x)
+        return x
+
+
+class StyleEncoder(nn.Module):
+    r"""Improved FUNIT Style Encoder. This is basically the same as the
+    original FUNIT Style Encoder.
+
+    Args:
+        num_downsamples (int): Number of times we reduce resolution by
+            2x2.
+        image_channels (int): Number of input image channels.
+        num_filters (int): Base filter number.
+        style_channels (int): Style code dimension.
+        padding_mode (str): Padding mode.
+        activation_norm_type (str): Type of activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        nonlinearity (str): Nonlinearity.
+    """
+
+    def __init__(self,
+                 num_downsamples,
+                 image_channels,
+                 num_filters,
+                 style_channels,
+                 padding_mode,
+                 activation_norm_type,
+                 weight_norm_type,
+                 nonlinearity):
+        super().__init__()
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type=activation_norm_type,
+                           weight_norm_type=weight_norm_type,
+                           nonlinearity=nonlinearity,
+                           inplace_nonlinearity=True)
+        model = []
+        model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
+                              **conv_params)]
+        for i in range(2):
+            model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
+                                  **conv_params)]
+            num_filters *= 2
+        for i in range(num_downsamples - 2):
+            model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
+                                  **conv_params)]
+        model += [nn.AdaptiveAvgPool2d(1)]
+        model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
+        self.model = nn.Sequential(*model)
+        self.output_dim = num_filters
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input image.
+        """
+        return self.model(x)
+
+
+class ContentEncoder(nn.Module):
+    r"""Improved FUNIT Content Encoder. This is basically the same as the
+    original FUNIT content encoder.
+
+    Args:
+        num_downsamples (int): Number of times we reduce resolution by
+           2x2.
+        num_res_blocks (int): Number of times we append residual block
+           after all the downsampling modules.
+        image_channels (int): Number of input image channels.
+        num_filters (int): Base filter number.
+        padding_mode (str): Padding mode.
+        activation_norm_type (str): Type of activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        nonlinearity (str): Nonlinearity.
+    """
+
+    def __init__(self,
+                 num_downsamples,
+                 num_res_blocks,
+                 image_channels,
+                 num_filters,
+                 padding_mode,
+                 activation_norm_type,
+                 weight_norm_type,
+                 nonlinearity):
+        super().__init__()
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type=activation_norm_type,
+                           weight_norm_type=weight_norm_type,
+                           nonlinearity=nonlinearity,
+                           inplace_nonlinearity=True,
+                           order='CNACNA')
+        model = []
+        model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
+                              **conv_params)]
+        dims = num_filters
+        for i in range(num_downsamples):
+            model += [Conv2dBlock(dims, dims * 2, 4, 2, 1, **conv_params)]
+            dims *= 2
+
+        for _ in range(num_res_blocks):
+            model += [Res2dBlock(dims, dims, learn_shortcut=False, **conv_params)]
+        self.model = nn.Sequential(*model)
+        self.output_dim = dims
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input image.
+        """
+        return self.model(x)
+
+
+class MLP(nn.Module):
+    r"""Improved FUNIT style decoder.
+
+    Args:
+        input_dim (int): Input dimension (style code dimension).
+        output_dim (int): Output dimension (to be fed into the AdaIN
+           layer).
+        latent_dim (int): Latent dimension.
+        num_layers (int): Number of layers in the MLP.
+        activation_norm_type (str): Type of activation normalization.
+        nonlinearity (str): Nonlinearity type.
+    """
+
+    def __init__(self,
+                 input_dim,
+                 output_dim,
+                 latent_dim,
+                 num_layers,
+                 activation_norm_type,
+                 nonlinearity):
+        super().__init__()
+        model = []
+        model += [LinearBlock(input_dim, latent_dim,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity)]
+        # changed from num_layers - 2 to num_layers - 3.
+        for i in range(num_layers - 3):
+            model += [LinearBlock(latent_dim, latent_dim,
+                                  activation_norm_type=activation_norm_type,
+                                  nonlinearity=nonlinearity)]
+        model += [LinearBlock(latent_dim, output_dim,
+                              activation_norm_type=activation_norm_type,
+                              nonlinearity=nonlinearity)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+        """
+        return self.model(x.view(x.size(0), -1))
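+
+# Illustrative data flow for the modules above (a sketch, not part of the network
+# definition): the ContentEncoder keeps a spatial content code, the StyleEncoder pools
+# the image into a global style vector, the MLP maps that vector to the conditioning
+# input of the AdaIN layers, and the Decoder renders the output image.
+#
+#   content = content_encoder(images)           # N x C x H/2^k x W/2^k
+#   style = style_encoder(images)               # N x style_dims x 1 x 1
+#   images_out = decoder(content, mlp(style))   # N x num_image_channels x H x W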
diff --git a/imaginaire/generators/gancraft.py b/imaginaire/generators/gancraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..94fc34bee88e31fcdcf48f715f4d17f3de5bc37b
--- /dev/null
+++ b/imaginaire/generators/gancraft.py
@@ -0,0 +1,538 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import cv2
+import imageio
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import imaginaire.model_utils.gancraft.camctl as camctl
+import imaginaire.model_utils.gancraft.mc_utils as mc_utils
+import imaginaire.model_utils.gancraft.voxlib as voxlib
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.generators.gancraft_base import Base3DGenerator, RenderMLP  # noqa
+
+
+class Generator(Base3DGenerator):
+    r"""GANcraft generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super(Generator, self).__init__(gen_cfg, data_cfg)
+        print('GANcraft generator initialization.')
+
+        # Load voxels of the input world.
+        # The loaded voxel tensor has a shape of [X, Y, Z], dtype==torch.int32
+        # 0 means empty (air).
+        print('[Generator] Loading voxel world: ', gen_cfg.voxel_path)
+        if gen_cfg.voxel_path.endswith('.npy'):
+            voxel_t = np.load(gen_cfg.voxel_path)
+            voxel_t = torch.from_numpy(voxel_t.astype(np.int32))
+        else:
+            voxel_t = mc_utils.load_voxel_new(gen_cfg.voxel_path, shape=gen_cfg.voxel_shape)
+        print('[Generator] Loaded voxel world.')
+        self.voxel = mc_utils.McVoxel(voxel_t, preproc_ver=gen_cfg.voxel_preproc_ver)
+        blk_feats = torch.empty([self.voxel.nfilledvox, gen_cfg.blk_feat_dim], requires_grad=True)
+        self.blk_feats = nn.Parameter(blk_feats)  # Feature per voxel corner.
+
+        # Minecraft -> SPADE label translator.
+        self.label_trans = mc_utils.MCLabelTranslator()
+        self.num_reduced_labels = self.label_trans.get_num_reduced_lbls()
+        self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False)
+        self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False)
+        self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth)
+        self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False)
+        self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11)
+
+        # Load MLP model.
+        self.render_net = globals()[gen_cfg.mlp_model](
+            self.input_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims,
+            mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim,
+            **self.mlp_model_kwargs)
+
+        # Camera sampler.
+        self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random")
+        assert self.camera_sampler_type in ['random', 'traditional']
+        self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1)
+        self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1)
+        self.cam_res = gen_cfg.cam_res
+        self.crop_size = gen_cfg.crop_size
+
+        print('Done with the GANcraft generator initialization.')
+
+    def custom_init(self):
+        r"""Weight initialization of GANcraft components."""
+
+        self.blk_feats.data.uniform_(-1, 1)
+
+        def init_func(m):
+            if hasattr(m, 'weight'):
+                nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu')
+                m.weight.data *= 0.5
+            if hasattr(m, 'bias') and m.bias is not None:
+                m.bias.data.fill_(0.0)
+        self.apply(init_func)
+
+    def _get_batch(self, batch_size, device):
+        r"""Sample camera poses and perform ray-voxel intersection.
+
+        Args:
+            batch_size (int): Expected batch size of the current batch.
+            device (torch.device): Device on which the tensors should be stored.
+        """
+        with torch.no_grad():
+            voxel_id_batch = []
+            depth2_batch = []
+            raydirs_batch = []
+            cam_ori_t_batch = []
+            for b in range(batch_size):
+                while True:  # Rejection sampling.
+                    # Sample camera pose.
+                    if self.camera_sampler_type == 'random':
+                        cam_res = self.cam_res
+                        cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel)
+                        # ~24mm fov horizontal.
+                        cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
+                        cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
+                        cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
+                        cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
+                    elif self.camera_sampler_type == 'traditional':
+                        cam_res = self.cam_res
+                        cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
+                        dice = torch.rand(1).item()
+                        if dice > 0.5:
+                            cam_ori_t, cam_dir_t, cam_up_t, cam_f = \
+                                camctl.rand_camera_pose_tour(self.voxel)
+                            cam_f = cam_f * (cam_res[1]-1)
+                        else:
+                            cam_ori_t, cam_dir_t, cam_up_t = \
+                                camctl.rand_camera_pose_thridperson2(self.voxel)
+                            # ~24mm fov horizontal.
+                            cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
+
+                        cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
+                        cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
+                    else:
+                        raise NotImplementedError(
+                            'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type))
+                    # Run ray-voxel intersection test
+                    r"""Ray-voxel intersection CUDA kernel.
+                    Note: voxel_id = 0 and depth2 = NaN if there is no intersection along the ray
+
+                    Args:
+                        voxel_t (Y x 512 x 512 tensor, int32): Full 3D voxel of MC block IDs.
+                        cam_ori_t (3 tensor): Camera origin.
+                        cam_dir_t (3 tensor): Camera direction.
+                        cam_up_t (3 tensor): Camera up vector.
+                        cam_f (float): Camera focal length (in pixels).
+                        cam_c  (list of 2 floats [x, y]): Camera optical center.
+                        img_dims (list of 2 ints [H, W]): Camera resolution.
+                        max_samples (int): Maximum number of blocks intersected along the ray before stopping.
+                    Returns:
+                        voxel_id (img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of the voxels intersected
+                        along each ray.
+                        depth2 (2 x img_dims[0] x img_dims[1] x max_samples x 1 tensor): Depths of entrance and exit
+                        points for each ray-voxel intersection.
+                        raydirs (img_dims[0] x img_dims[1] x 1 x 3 tensor): The direction of each ray.
+
+                    """
+                    voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
+                        self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop,
+                        self.num_blocks_early_stop)
+
+                    if self.camera_rej_avg_depth > 0:
+                        depth_map = depth2[0, :, :, 0, :]
+                        avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)])
+                        if avg_depth < self.camera_rej_avg_depth:
+                            continue
+
+                    # Reject low entropy.
+                    if self.camera_min_entropy > 0:
+                        # Check entropy.
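+                        # maskcnt is the empirical distribution of block labels over pixels
+                        # (taking the first intersected voxel of each ray); maskentropy is its
+                        # Shannon entropy, so views dominated by a single label (low entropy,
+                        # e.g. all sky) are rejected.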
+                        maskcnt = torch.bincount(
+                            torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \
+                            (voxel_id.size(0)*voxel_id.size(1))
+                        maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10))
+                        if maskentropy < self.camera_min_entropy:
+                            continue
+                    break
+
+                voxel_id_batch.append(voxel_id)
+                depth2_batch.append(depth2)
+                raydirs_batch.append(raydirs)
+                cam_ori_t_batch.append(cam_ori_t)
+            voxel_id = torch.stack(voxel_id_batch, dim=0)
+            depth2 = torch.stack(depth2_batch, dim=0)
+            raydirs = torch.stack(raydirs_batch, dim=0)
+            cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device)
+            cam_poses = None
+        return voxel_id, depth2, raydirs, cam_ori_t, cam_poses
+
+    def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False):
+        r"""Evaluating img2img network to obtain pseudo-ground truth images.
+
+        Args:
+            pseudo_gen (callable): Function converting mask to image using img2img network.
+            voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of the voxels intersected along
+            each ray.
+            z (N x C tensor): Optional style code passed to pseudo_gen.
+            style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen.
+            resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution.
+            deterministic (bool): If True, disable stochastic label mapping.
+        """
+        with torch.no_grad():
+            mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long()
+            coco_mask = self.label_trans.mc2coco(mc_mask) - 1
+            coco_mask[coco_mask < 0] = 183
+
+            if not deterministic:
+                # Stochastic mapping
+                dice = torch.rand(1).item()
+                if dice > 0.5 and dice < 0.9:
+                    coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds')
+                elif dice >= 0.9:
+                    coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog')
+                dice = torch.rand(1).item()
+                if dice > 0.33 and dice < 0.66:
+                    coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea')
+                elif dice >= 0.66:
+                    coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river')
+
+            fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)],
+                                     dtype=torch.half, device=voxel_id.device)
+            fake_masks.scatter_(1, coco_mask, 1.0)
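+            # scatter_ along dim=1 builds a one-hot label map:
+            # fake_masks[n, coco_mask[n, 0, h, w], h, w] = 1.0 for every pixel (n, h, w).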
+
+            if self.use_label_smooth_pgt:
+                fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
+            if self.pad > 0:
+                fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+            # Generate pseudo GT using GauGAN.
+            if resize_512:
+                fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest')
+            else:
+                fake_masks_512 = fake_masks
+            pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img)
+
+            # NaN/Inf guard. NaN can occur on Volta GPUs.
+            nan_mask = torch.isnan(pseudo_real_img)
+            inf_mask = torch.isinf(pseudo_real_img)
+            pseudo_real_img[nan_mask | inf_mask] = 0.0
+            if resize_512:
+                pseudo_real_img = F.interpolate(
+                    pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area')
+            pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1)
+
+        return pseudo_real_img, fake_masks
+
+    def sample_camera(self, data, pseudo_gen):
+        r"""Sample camera randomly and precompute everything used by both Gen and Dis.
+
+        Args:
+            data (dict):
+                images (N x 3 x H x W tensor) : Real images
+                label (N x C2 x H x W tensor) : Segmentation map
+            pseudo_gen (callable): Function converting mask to image using img2img network.
+        Returns:
+            ret (dict):
+                voxel_id (N x H x W x max_samples x 1 tensor): IDs of the voxels intersected along each ray.
+                depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
+                intersection.
+                raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+                cam_ori_t (N x 3 tensor): Camera origins.
+                pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image.
+                real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels.
+                fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views.
+        """
+        device = torch.device('cuda')
+        batch_size = data['images'].size(0)
+        # ================ Assemble a batch ==================
+        # Requires: voxel_id, depth2, raydirs, cam_ori_t.
+        voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device)
+        ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t}
+
+        if pseudo_gen is not None:
+            pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id)
+            ret['pseudo_real_img'] = pseudo_real_img.float()
+
+        # =============== Mask translation ================
+        real_masks = data['label']
+        if self.reduced_label_set:
+            # Translate fake mask (directly from mcid).
+            # convert unrecognized labels to 'dirt'.
+            # N C H W [1 1 80 80]
+            reduce_fake_mask = self.label_trans.mc2reduced(
+                voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long(), ign2dirt=True)
+            reduce_fake_mask_onehot = torch.zeros([
+                reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)],
+                dtype=torch.float, device=device)
+            reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0)
+            fake_masks = reduce_fake_mask_onehot
+            if self.pad != 0:
+                fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+            # Translate real mask (data['label']), which is onehot.
+            real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True)
+            real_masks_idx[real_masks_idx > 182] = 182
+
+            reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx)
+            reduced_real_mask_onehot = torch.zeros([
+                reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2),
+                reduced_real_mask.size(3)], dtype=torch.float, device=device)
+            reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0)
+            real_masks = reduced_real_mask_onehot
+
+        # Mask smoothing.
+        if self.use_label_smooth:
+            fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
+        if self.use_label_smooth_real:
+            real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia)
+
+        ret['real_masks'] = real_masks
+        ret['fake_masks'] = fake_masks
+
+        return ret
+
+    def forward(self, data, random_style=False):
+        r"""GANcraft Generator forward.
+
+        Args:
+            data (dict):
+                images (N x 3 x H x W tensor) : Real images
+                voxel_id (N x H x W x max_samples x 1 tensor): IDs of the voxels intersected along each ray.
+                depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
+                intersection.
+                raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+                cam_ori_t (N x 3 tensor): Camera origins.
+            random_style (bool): Whether to sample a random style vector.
+        Returns:
+            output (dict):
+                fake_images (N x 3 x H x W tensor): fake images
+                mu (N x C1 tensor): mean vectors
+                logvar (N x C1 tensor): log-variance vectors
+        """
+        device = torch.device('cuda')
+        batch_size = data['images'].size(0)
+
+        # ================ Assemble a batch ==================
+        # Requires: voxel_id, depth2, raydirs, cam_ori_t.
+        voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t']
+        if 'pseudo_real_img' in data:
+            pseudo_real_img = data['pseudo_real_img']
+
+        z, mu, logvar = None, None, None
+        if random_style:
+            if self.style_dims > 0:
+                z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
+        else:
+            if self.style_encoder is None:
+                # ================ Get Style Code =================
+                if self.style_dims > 0:
+                    z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
+            else:
+                mu, logvar, z = self.style_encoder(pseudo_real_img)
+
+        # ================ Network Forward ================
+        # Forward StyleNet
+        if self.style_net is not None:
+            z = self.style_net(z)
+
+        # Forward per-pixel net.
+        net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \
+            sky_mask, sky_only_mask, new_idx = self._forward_perpix(
+                self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z)
+
+        # Forward global net.
+        fake_images, fake_images_raw = self._forward_global(net_out, z)
+        if self.pad != 0:
+            fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+        # =============== Arrange Return Values ================
+        output = {}
+        output['fake_images'] = fake_images
+        output['mu'] = mu
+        output['logvar'] = logvar
+        return output
+
+    def inference(self,
+                  output_dir,
+                  camera_mode,
+                  style_img_path=None,
+                  seed=1,
+                  pad=30,
+                  num_samples=40,
+                  num_blocks_early_stop=6,
+                  sample_depth=3,
+                  tile_size=128,
+                  resolution_hw=[540, 960],
+                  cam_ang=72,
+                  cam_maxstep=10):
+        r"""Compute result images according to the provided camera trajectory and save the results in the specified
+        folder. The full image is evaluated in multiple tiles to save memory.
+
+        Args:
+            output_dir (str): Where should the results be stored.
+            camera_mode (int): Which camera trajectory to use.
+            style_img_path (str): Path to the style-conditioning image.
+            seed (int): Random seed (controls the style when style_img_path is not specified).
+            pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than
+            the receptive field of the CNN to avoid border artifacts.
+            num_samples (int): Number of samples per ray (different from training).
+            num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
+            (different from training).
+            sample_depth (float): Max distance traveled through boxes before stopping (different from training).
+            tile_size (int): Max size of a tile in pixels.
+            resolution_hw (list [H, W]): Resolution of the output image.
+            cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
+            cam_maxstep (int): Number of frames sampled from the camera trajectory.
+        """
+
+        def write_img(path, img, rgb_input=False):
+            img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
+            img = img[0].transpose(1, 2, 0)
+            if rgb_input:
+                img = img[..., [2, 1, 0]]
+            cv2.imwrite(path, img,  [cv2.IMWRITE_PNG_COMPRESSION, 4])
+            return img[..., ::-1]
+
+        def read_img(path):
+            img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
+            img = img * 2 - 1
+            img = torch.from_numpy(img)
+            return img
+
+        print('Saving to', output_dir)
+
+        # Use provided random seed.
+        device = torch.device('cuda')
+        rng_cuda = torch.Generator(device=device)
+        rng_cuda = rng_cuda.manual_seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+        self.pad = pad
+        self.num_samples = num_samples
+        self.num_blocks_early_stop = num_blocks_early_stop
+        self.sample_depth = sample_depth
+
+        self.coarse_deterministic_sampling = True
+        self.crop_size = resolution_hw
+        self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
+        self.use_label_smooth_pgt = False
+
+        # Make output dirs.
+        gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs')
+        os.makedirs(gancraft_outputs_dir, exist_ok=True)
+        vis_masks_dir = os.path.join(output_dir, 'vis_masks')
+        os.makedirs(vis_masks_dir, exist_ok=True)
+        fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=1)
+        fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=1)
+
+        evalcamctl = camctl.EvalCameraController(
+            self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
+            smooth_decay_multiplier=150/cam_maxstep)
+
+        # Get output style.
+        if style_img_path is None:
+            z = torch.empty(1, self.style_dims, dtype=torch.float32, device=device)
+            z.normal_(generator=rng_cuda)
+        else:
+            style_img = read_img(style_img_path)
+            style_img = style_img.to(device).unsqueeze(0)
+            mu, logvar, z = self.style_encoder(style_img)
+        z = self.style_net(z)
+
+        # Generate required output images.
+        for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
+            print('Rendering frame', id)
+            cam_f = cam_f * (self.crop_size[1]-1)  # So that the view does not depend on the padding.
+            cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
+
+            voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
+                self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
+                self.num_blocks_early_stop)
+
+            voxel_id = voxel_id.unsqueeze(0)
+            depth2 = depth2.unsqueeze(0)
+            raydirs = raydirs.unsqueeze(0)
+            cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
+
+            # Save 3D voxel rendering.
+            mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy())
+            # Diffused shading, co-located light.
+            first_intersection_depth = depth2[:, 0, :, :, 0, None, :]  # [1, 542, 542, 1, 1].
+            first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :]
+            fip_local_coords = torch.remainder(first_intersection_point, 1.0)
+            fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords)
+            fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False)
+            # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1]
+            lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32,
+                               device=fip_wall_orientation.device)
+            fip_normal = lut[fip_wall_orientation]  # [1, 542, 542, 1, 3]
+            diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1))
+
+            mc_rgb = (mc_rgb.astype(np.float32) / 255) ** 2.2
+            mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy()
+            mc_rgb = (mc_rgb ** (1/2.2)) * 255
+            mc_rgb = mc_rgb.astype(np.uint8)
+            if self.pad > 0:
+                mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+            cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb,  [cv2.IMWRITE_PNG_COMPRESSION, 4])
+
+            # Tiled eval of GANcraft.
+            voxel_id_all = voxel_id
+            depth2_all = depth2
+            raydirs_all = raydirs
+
+            # Evaluate sky in advance to get a consistent sky in the semi-transparent region.
+            if self.sky_global_avgpool:
+                sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+                sky_raydirs_in = voxlib.positional_encoding(
+                    sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
+                skynet_out_c = self.sky_net(sky_raydirs_in, z)
+                sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
+                self.sky_avg = sky_avg
+
+            num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size
+            num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size
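+            # ceil((cam_res - pad) / tile_size) tiles cover each axis; every tile is rendered
+            # at tile_size + pad and cropped by pad//2 per side below, so neighbouring tiles
+            # stitch seamlessly into the final crop_size image.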
+
+            fake_images_chunks_v = []
+            # For each horizontal strip.
+            for strip_id_h in range(num_strips_h):
+                strip_begin_h = strip_id_h * tile_size
+                strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
+                # For each vertical strip.
+                fake_images_chunks_h = []
+                for strip_id_w in range(num_strips_w):
+                    strip_begin_w = strip_id_w * tile_size
+                    strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
+
+                    voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+                    depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+                    raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+
+                    net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
+                        nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
+                            self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z)
+                    fake_images, _ = self._forward_global(net_out, z)
+
+                    if self.pad != 0:
+                        fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+                    fake_images_chunks_h.append(fake_images)
+                fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
+                fake_images_chunks_v.append(fake_images_h)
+            fake_images = torch.cat(fake_images_chunks_v, dim=-2)
+            rgb = write_img(os.path.join(gancraft_outputs_dir,
+                            '{:05d}.png'.format(id)), fake_images, rgb_input=True)
+            fout.append_data(rgb)
+            fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1))
+        fout.close()
+        fout_cat.close()
diff --git a/imaginaire/generators/gancraft_base.py b/imaginaire/generators/gancraft_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef164b67053af5d228bda6c9aea75c261d7b114f
--- /dev/null
+++ b/imaginaire/generators/gancraft_base.py
@@ -0,0 +1,603 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import re
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import Conv2dBlock, LinearBlock
+from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear
+import imaginaire.model_utils.gancraft.mc_utils as mc_utils
+import imaginaire.model_utils.gancraft.voxlib as voxlib
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class RenderMLP(nn.Module):
+    r""" MLP with affine modulation."""
+
+    def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
+                 out_channels_s=1, out_channels_c=3, hidden_channels=256,
+                 use_seg=True):
+        super(RenderMLP, self).__init__()
+
+        self.use_seg = use_seg
+        if self.use_seg:
+            self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)
+
+        self.fc_viewdir = None
+        if viewdir_dim > 0:
+            self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)
+
+        self.fc_1 = nn.Linear(in_channels, hidden_channels)
+
+        self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+        self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+        self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+
+        self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)
+
+        if viewdir_dim > 0:
+            self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
+            self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
+        else:
+            self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
+                                  bias=False, mod_bias=True, output_mode=True)
+        self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
+
+        self.act = nn.LeakyReLU(negative_slope=0.2)
+
+    def forward(self, x, raydir, z, m):
+        r""" Forward network
+
+        Args:
+            x (N x H x W x M x in_channels tensor): Projected features.
+            raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
+            z (N x style_dim tensor): Style codes.
+            m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
+        """
+        b, h, w, n, _ = x.size()
+        z = z[:, None, None, None, :]
+
+        f = self.fc_1(x)
+        if self.use_seg:
+            f = f + self.fc_m_a(m)
+        # Common MLP
+        f = self.act(f)
+        f = self.act(self.fc_2(f, z))
+        f = self.act(self.fc_3(f, z))
+        f = self.act(self.fc_4(f, z))
+
+        # Sigma MLP
+        sigma = self.fc_sigma(f)
+
+        # Color MLP
+        if self.fc_viewdir is not None:
+            f = self.fc_5(f)
+            f = f + self.fc_viewdir(raydir)
+            f = self.act(self.mod_5(f, z))
+        else:
+            f = self.act(self.fc_5(f, z))
+        f = self.act(self.fc_6(f, z))
+        c = self.fc_out_c(f)
+        return sigma, c
+
+
+class StyleMLP(nn.Module):
+    r"""MLP converting style code to intermediate style representation."""
+
+    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
+                 output_act=True):
+        super(StyleMLP, self).__init__()
+
+        self.normalize_input = normalize_input
+        self.output_act = output_act
+        fc_layers = []
+        fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
+        for i in range(num_layers-1):
+            fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
+        self.fc_layers = nn.ModuleList(fc_layers)
+
+        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
+
+        if leaky_relu:
+            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            self.act = functools.partial(F.relu, inplace=True)
+
+    def forward(self, z):
+        r""" Forward network
+
+        Args:
+            z (N x style_dim tensor): Style codes.
+        """
+        if self.normalize_input:
+            z = F.normalize(z, p=2, dim=-1)
+        for fc_layer in self.fc_layers:
+            z = self.act(fc_layer(z))
+        z = self.fc_out(z)
+        if self.output_act:
+            z = self.act(z)
+        return z
+
+
+class SKYMLP(nn.Module):
+    r"""MLP converting ray directions to sky features."""
+
+    def __init__(self, in_channels, style_dim, out_channels_c=3,
+                 hidden_channels=256, leaky_relu=True):
+        super(SKYMLP, self).__init__()
+        self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)
+
+        self.fc1 = nn.Linear(in_channels, hidden_channels)
+        self.fc2 = nn.Linear(hidden_channels, hidden_channels)
+        self.fc3 = nn.Linear(hidden_channels, hidden_channels)
+        self.fc4 = nn.Linear(hidden_channels, hidden_channels)
+        self.fc5 = nn.Linear(hidden_channels, hidden_channels)
+
+        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
+
+        if leaky_relu:
+            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            self.act = functools.partial(F.relu, inplace=True)
+
+    def forward(self, x, z):
+        r"""Forward network
+
+        Args:
+            x (... x in_channels tensor): Ray direction embeddings.
+            z (... x style_dim tensor): Style codes.
+        """
+
+        z = self.fc_z_a(z)
+        while z.dim() < x.dim():
+            z = z.unsqueeze(1)
+
+        y = self.act(self.fc1(x) + z)
+        y = self.act(self.fc2(y))
+        y = self.act(self.fc3(y))
+        y = self.act(self.fc4(y))
+        y = self.act(self.fc5(y))
+        c = self.fc_out_c(y)
+
+        return c
+
+
+class RenderCNN(nn.Module):
+    r"""CNN converting intermediate feature map to final image."""
+
+    def __init__(self, in_channels, style_dim, hidden_channels=256,
+                 leaky_relu=True):
+        super(RenderCNN, self).__init__()
+        self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
+
+        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
+        self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+        self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+        self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+        self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+        self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+        self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+
+        self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
+
+        if leaky_relu:
+            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            self.act = functools.partial(F.relu, inplace=True)
+
+    def modulate(self, x, w, b):
+        w = w[..., None, None]
+        b = b[..., None, None]
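+        # FiLM-style per-channel affine modulation; the +1 keeps the layer close to the
+        # identity mapping when the predicted scale w is near zero.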
+        return x * (w+1) + b
+
+    def forward(self, x, z):
+        r"""Forward network.
+
+        Args:
+            x (N x in_channels x H x W tensor): Intermediate feature map
+            z (N x style_dim tensor): Style codes.
+        """
+        z = self.fc_z_cond(z)
+        adapt = torch.chunk(z, 2 * 2, dim=-1)
+
+        y = self.act(self.conv1(x))
+
+        y = y + self.conv2b(self.act(self.conv2a(y)))
+        y = self.act(self.modulate(y, adapt[0], adapt[1]))
+
+        y = y + self.conv3b(self.act(self.conv3a(y)))
+        y = self.act(self.modulate(y, adapt[2], adapt[3]))
+
+        y = y + self.conv4b(self.act(self.conv4a(y)))
+        y = self.act(y)
+
+        y = self.conv4(y)
+
+        return y
+
+
+class StyleEncoder(nn.Module):
+    r"""Style Encoder constructor.
+
+    Args:
+        style_enc_cfg (obj): Style encoder definition part of the config.
+    """
+
+    def __init__(self, style_enc_cfg):
+        super(StyleEncoder, self).__init__()
+        input_image_channels = style_enc_cfg.input_image_channels
+        num_filters = style_enc_cfg.num_filters
+        kernel_size = style_enc_cfg.kernel_size
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        style_dims = style_enc_cfg.style_dims
+        weight_norm_type = style_enc_cfg.weight_norm_type
+        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+        activation_norm_type = 'none'
+        nonlinearity = 'leakyrelu'
+        base_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=2,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              # inplace_nonlinearity=True,
+                              nonlinearity=nonlinearity)
+        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
+        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+        if not self.no_vae:
+            self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+
+    def forward(self, input_x):
+        r"""SPADE Style Encoder forward.
+
+        Args:
+            input_x (N x 3 x H x W tensor): input images.
+        Returns:
+            mu (N x C tensor): Mean vectors.
+            logvar (N x C tensor): Log-variance vectors.
+            z (N x C tensor): Style code vectors.
+        """
+        if input_x.size(2) != 256 or input_x.size(3) != 256:
+            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+        x = self.layer1(input_x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = x.view(x.size(0), -1)
+        mu = self.fc_mu(x)
+        if not self.no_vae:
+            logvar = self.fc_var(x)
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            z = eps.mul(std) + mu
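+            # Reparameterization trick: z ~ N(mu, sigma^2) is drawn as mu + sigma * eps with
+            # eps ~ N(0, I), keeping the sample differentiable w.r.t. mu and logvar.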
+        else:
+            z = mu
+            logvar = torch.zeros_like(mu)
+        return mu, logvar, z
+
+
+class Base3DGenerator(nn.Module):
+    r"""Minecraft 3D generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super(Base3DGenerator, self).__init__()
+        print('Base3DGenerator initialization.')
+
+        # ---------------------- Main Network ------------------------
+        # Exclude some of the features from positional encoding
+        self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)
+
+        # blk_feat passes through PE
+        input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim
+        if (gen_cfg.pe_incl_orig_feat):
+            input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)
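+        # Example with illustrative numbers (not from a shipped config): blk_feat_dim=32,
+        # pe_no_pe_feat_dim=0, pe_lvl_feat=4 and pe_incl_orig_feat=True give
+        # 32 * (4 * 2) + 32 = 288 input channels.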
+        print('[Base3DGenerator] Expected input dimensions: ', input_dim)
+        self.input_dim = input_dim
+
+        self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
+        self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
+        if self.pe_lvl_localcoords > 0:
+            self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3
+
+        # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
+        input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2)
+        if (gen_cfg.pe_incl_orig_raydir):
+            input_dim_viewdir += 3
+        print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
+        self.input_dim_viewdir = input_dim_viewdir
+
+        self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
+                          gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]
+
+        # Style input dimension
+        style_dims = gen_cfg.style_dims
+        self.style_dims = style_dims
+        interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
+        self.interm_style_dims = interm_style_dims
+        # ---------------------- Style MLP --------------------------
+        self.style_net = globals()[gen_cfg.stylenet_model](
+            style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)
+
+        # number of output channels for MLP (before blending)
+        final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
+        self.final_feat_dim = final_feat_dim
+
+        # ----------------------- Sky Network -------------------------
+        sky_input_dim_base = 3
+        # Dedicated sky network input dimensions
+        sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2)
+        if (gen_cfg.pe_incl_orig_raydir_sky):
+            sky_input_dim += sky_input_dim_base
+        print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
+        self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
+        self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)
+
+        # ----------------------- Style Encoder -------------------------
+        style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
+        if style_enc_cfg is not None:
+            setattr(style_enc_cfg, 'input_image_channels', 3)
+            setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
+            self.style_encoder = StyleEncoder(style_enc_cfg)
+        else:
+            self.style_encoder = None
+
+        # ---------------------- Ray Caster -------------------------
+        self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
+        self.num_samples = gen_cfg.num_samples
+        self.sample_depth = gen_cfg.sample_depth
+        self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
+        self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)
+
+        # ---------------------- Blender -------------------------
+        self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
+        self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
+        self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
+        self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
+        self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
+        keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
+        self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
+        if self.keep_sky_out:
+            self.sky_replace_color = None
+            if keep_sky_out_learnbg:
+                sky_replace_color = torch.zeros([final_feat_dim])
+                sky_replace_color.requires_grad = True
+                self.sky_replace_color = torch.nn.Parameter(sky_replace_color)
+        # ---------------------- render_cnn -------------------------
+        self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)
+        self.pad = gen_cfg.pad
+
+    def get_param_groups(self, cfg_opt):
+        print('[Generator] get_param_groups')
+
+        if hasattr(cfg_opt, 'ignore_parameters'):
+            print('[Generator::get_param_groups] [x]: ignored.')
+            optimize_parameters = []
+            for k, x in self.named_parameters():
+                match = False
+                for m in cfg_opt.ignore_parameters:
+                    if re.match(m, k) is not None:
+                        match = True
+                        print(' [x]', k)
+                        break
+                if match is False:
+                    print(' [v]', k)
+                    optimize_parameters.append(x)
+        else:
+            optimize_parameters = self.parameters()
+
+        param_groups = []
+        param_groups.append({'params': optimize_parameters})
+
+        if hasattr(cfg_opt, 'param_groups'):
+            optimized_param_names = []
+            all_param_names = [k for k, v in self.named_parameters()]
+            param_groups = []
+            for k, v in cfg_opt.param_groups.items():
+                print('[Generator::get_param_groups] Adding param group from config:', k, v)
+                params = getattr(self, k)
+                named_parameters = [k]
+                if issubclass(type(params), nn.Module):
+                    named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()]
+                    params = params.parameters()
+                param_groups.append({'params': params, **v})
+                optimized_param_names.extend(named_parameters)
+
+            print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n    ',
+                  set(all_param_names) - set(optimized_param_names))
+
+        return param_groups
+
+    def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
+        r"""Forwarding the MLP.
+
+        Args:
+            blk_feats (K x C1 tensor): Sparse block features.
+            worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points.
+            raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
+            z (N x C3 tensor): Intermediate style vectors.
+            mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
+        Returns:
+            net_out_s (N x H x W x L x 1 tensor): Opacities.
+            net_out_c (N x H x W x L x C5 tensor): Color embeddings.
+        """
+        proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
+            blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)
+
+        render_net_extra_kwargs = {}
+        if self.pe_lvl_localcoords > 0:
+            local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
+            # Scale to [0, 2], as the positional encoding function doesn't have internal x2
+            local_coords[torch.isnan(local_coords)] = 0.0
+            local_coords = local_coords.contiguous()
+            poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
+            render_net_extra_kwargs['poscode'] = poscode
+
+        if self.pe_params[0] == 0 and self.pe_params[1] is True:  # no PE shortcut, saves ~400MB
+            feature_in = proj_feature
+        else:
+            if self.pe_no_pe_feat_dim > 0:
+                feature_in = voxlib.positional_encoding(
+                    proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1])
+                feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
+            else:
+                feature_in = voxlib.positional_encoding(
+                    proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])
+
+        net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs)
+
+        if self.raw_noise_std > 0.:
+            noise = torch.randn_like(net_out_s) * self.raw_noise_std
+            net_out_s = net_out_s + noise
+
+        return net_out_s, net_out_c
+
+    def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
+        r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
+
+        Args:
+            blk_feats (K x C1 tensor): Sparse block features.
+            voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels
+            depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
+            raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+            cam_ori_t (N x 3 tensor): Camera origins.
+            z (N x C3 tensor): Intermediate style vectors.
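+        Returns:
+            (tuple): Blended per-pixel features together with the per-sample
+                distances, weights, raw MLP outputs, sky colors and the masks
+                used during blending.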
+        """
+        # Generate sky_mask; PE transform on ray direction.
+        with torch.no_grad():
+            raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+            if self.pe_params[2] == 0 and self.pe_params[3] is True:
+                raydirs_in = raydirs_in  # Use raw ray directions without positional encoding.
+            elif self.pe_params[2] == 0 and self.pe_params[3] is False:  # Not using raydir at all
+                raydirs_in = None
+            else:
+                raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
+
+            # sky_mask: when True, ray finally hits sky
+            sky_mask = voxel_id[:, :, :, [-1], :] == 0
+            # sky_only_mask: when True, ray hits nothing but sky
+            sky_only_mask = voxel_id[:, :, :, [0], :] == 0
+
+        with torch.no_grad():
+            # Random sample points along the ray
+            num_samples = self.num_samples + 1
+            if self.sample_use_box_boundaries:
+                num_samples = self.num_samples - self.num_blocks_early_stop
+
+            # 10 samples per ray + 4 intersections - 2
+            rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
+                depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
+                use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
+
+            worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
+
+            # Generate per-sample segmentation label
+            voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
+            mc_masks = torch.gather(voxel_id_reduced, -2, new_idx)  # B 256 256 N 1
+            mc_masks = mc_masks.long()
+            mc_masks_onehot = torch.zeros(
+                [mc_masks.size(0), mc_masks.size(1), mc_masks.size(2), mc_masks.size(3), self.num_reduced_labels],
+                dtype=torch.float, device=voxel_id.device)
+            # mc_masks_onehot: [B H W Nlayer 680]
+            mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
+
+        net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)
+
+        # Handle sky
+        sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+        sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
+        skynet_out_c = self.sky_net(sky_raydirs_in, z)
+
+        # Blending
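+        # Volume-rendering style compositing: convert per-sample opacities into
+        # per-sample blending weights along each ray.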
+        weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
+
+        # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
+        weights = weights * torch.logical_not(sky_only_mask).float()
+        total_weights_raw = torch.sum(weights, dim=-2, keepdim=True)  # 256 256 1 1
+        total_weights = total_weights_raw
+
+        is_gnd = worldcoord2[..., [0]] <= 1.0  # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
+        is_gnd = is_gnd.any(dim=-2, keepdim=True)
+        nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
+        nosky_mask = nosky_mask.float()
+
+        # Avoid sky leakage
+        sky_weight = 1.0-total_weights
+        if self.keep_sky_out:
+            # keep_sky_out_avgpool overrides sky_replace_color
+            if self.sky_replace_color is None or self.keep_sky_out_avgpool:
+                if self.keep_sky_out_avgpool:
+                    if hasattr(self, 'sky_avg'):
+                        sky_avg = self.sky_avg
+                    else:
+                        if self.sky_global_avgpool:
+                            sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
+                        else:
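+                            # Spatially smooth the predicted sky colors with a
+                            # large (31x31) average pool before using them to
+                            # fill non-sky regions.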
+                            skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1)
+                            sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
+                            sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2)
+                    # print(sky_avg.shape)
+                    skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
+                else:
+                    sky_weight = sky_weight * (1.0-nosky_mask)
+            else:
+                skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
+
+        if self.clip_feat_map is True:  # intermediate feature before blending & CNN
+            rgbs = torch.clamp(net_out_c, -1, 1) + 1
+            rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
+            net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
+            net_out = net_out.squeeze(-2)
+            net_out = net_out - 1
+        elif self.clip_feat_map is False:
+            rgbs = net_out_c
+            rgbs_sky = skynet_out_c
+            net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
+            net_out = net_out.squeeze(-2)
+        elif self.clip_feat_map == 'tanh':
+            rgbs = torch.tanh(net_out_c)
+            rgbs_sky = torch.tanh(skynet_out_c)
+            net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
+            net_out = net_out.squeeze(-2)
+        else:
+            raise NotImplementedError
+
+        return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
+            nosky_mask, sky_mask, sky_only_mask, new_idx
+
+    def _forward_global(self, net_out, z):
+        r"""Forward the CNN
+
+        Args:
+            net_out (N x H x W x C5 tensor): Intermediate feature maps.
+            z (N x C3 tensor): Intermediate style vectors.
+
+        Returns:
+            fake_images (N x 3 x H x W tensor): Output image.
+            fake_images_raw (N x 3 x H x W tensor): Output image before TanH.
+        """
+        fake_images = net_out.permute(0, 3, 1, 2)
+        fake_images_raw = self.denoiser(fake_images, z)
+        fake_images = torch.tanh(fake_images_raw)
+
+        return fake_images, fake_images_raw
diff --git a/imaginaire/generators/munit.py b/imaginaire/generators/munit.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cb066bb51e6f8e96d623d2224af06af017c086
--- /dev/null
+++ b/imaginaire/generators/munit.py
@@ -0,0 +1,465 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.generators.unit import ContentEncoder
+
+
+class Generator(nn.Module):
+    r"""Improved MUNIT generator.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
+        self.autoencoder_b = AutoEncoder(**vars(gen_cfg))
+
+    def forward(self, data, random_style=True, image_recon=True,
+                latent_recon=True, cycle_recon=True, within_latent_recon=False):
+        r"""In MUNIT's forward pass, it generates a content code and a style
+        code from images in both domain. It then performs a within-domain
+        reconstruction step and a cross-domain translation step.
+        In within-domain reconstruction, it reconstructs an image using the
+        content and style from the same image and optionally encodes the image
+        back to the latent space.
+        In cross-domain translation, it generates a translated image by mixing
+        the content and style from images in different domains, and optionally
+        encodes the image back to the latent space.
+
+        Args:
+            data (dict): Training data at the current iteration.
+              - images_a (tensor): Images from domain A.
+              - images_b (tensor): Images from domain B.
+            random_style (bool): If ``True``, samples the style code from the
+                prior distribution, otherwise uses the style code encoded from
+                the input images in the other domain.
+            image_recon (bool): If ``True``, also returns reconstructed images.
+            latent_recon (bool): If ``True``, also returns reconstructed latent
+                code during cross-domain translation.
+            cycle_recon (bool): If ``True``, also returns cycle
+                reconstructed images.
+            within_latent_recon (bool): If ``True``, also returns reconstructed
+                latent code during within-domain reconstruction.
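+
+        Returns:
+            net_G_output (dict): Generated images and the content/style codes
+                required by the reconstruction, translation and cycle losses.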
+        """
+
+        images_a = data['images_a']
+        images_b = data['images_b']
+        net_G_output = dict()
+
+        # encode input images into content and style code
+        content_a, style_a = self.autoencoder_a.encode(images_a)
+        content_b, style_b = self.autoencoder_b.encode(images_b)
+
+        # decode (within domain)
+        if image_recon:
+            images_aa = self.autoencoder_a.decode(content_a, style_a)
+            images_bb = self.autoencoder_b.decode(content_b, style_b)
+            net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))
+
+        # decode (cross domain)
+        if random_style:  # use randomly sampled style code
+            style_a_rand = torch.randn_like(style_a)
+            style_b_rand = torch.randn_like(style_b)
+        else:  # use style code encoded from the other domain
+            style_a_rand = style_a
+            style_b_rand = style_b
+        images_ba = self.autoencoder_a.decode(content_b, style_a_rand)
+        images_ab = self.autoencoder_b.decode(content_a, style_b_rand)
+
+        # encode translated images into content and style code
+        if latent_recon or cycle_recon:
+            content_ba, style_ba = self.autoencoder_a.encode(images_ba)
+            content_ab, style_ab = self.autoencoder_b.encode(images_ab)
+            net_G_output.update(dict(content_ba=content_ba, style_ba=style_ba,
+                                     content_ab=content_ab, style_ab=style_ab))
+
+        # encode reconstructed images into content and style code
+        if image_recon and within_latent_recon:
+            content_aa, style_aa = self.autoencoder_a.encode(images_aa)
+            content_bb, style_bb = self.autoencoder_b.encode(images_bb)
+            net_G_output.update(dict(content_aa=content_aa, style_aa=style_aa,
+                                     content_bb=content_bb, style_bb=style_bb))
+
+        # cycle reconstruction
+        if cycle_recon:
+            images_aba = self.autoencoder_a.decode(content_ab, style_a)
+            images_bab = self.autoencoder_b.decode(content_ba, style_b)
+            net_G_output.update(
+                dict(images_aba=images_aba, images_bab=images_bab))
+
+        # required outputs
+        net_G_output.update(dict(content_a=content_a, content_b=content_b,
+                                 style_a=style_a, style_b=style_b,
+                                 style_a_rand=style_a_rand,
+                                 style_b_rand=style_b_rand,
+                                 images_ba=images_ba, images_ab=images_ab))
+
+        return net_G_output
+
+    def inference(self, data, a2b=True, random_style=True):
+        r"""MUNIT inference.
+
+        Args:
+            data (dict): Training data at the current iteration.
+              - images_a (tensor): Images from domain A.
+              - images_b (tensor): Images from domain B.
+            a2b (bool): If ``True``, translates images from domain A to B,
+                otherwise from B to A.
+            random_style (bool): If ``True``, samples the style code from the
+                prior distribution, otherwise uses the style code encoded from
+                the input images in the other domain.
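+        Returns:
+            (tuple):
+              - output_images (tensor): Translated images.
+              - file_names (list): Names used when saving the outputs.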
+        """
+        if a2b:
+            input_key = 'images_a'
+            content_encode = self.autoencoder_a.content_encoder
+            style_encode = self.autoencoder_b.style_encoder
+            decode = self.autoencoder_b.decode
+        else:
+            input_key = 'images_b'
+            content_encode = self.autoencoder_b.content_encoder
+            style_encode = self.autoencoder_a.style_encoder
+            decode = self.autoencoder_a.decode
+
+        content_images = data[input_key]
+        content = content_encode(content_images)
+        if random_style:
+            style_channels = self.autoencoder_a.style_channels
+            style = torch.randn(content.size(0), style_channels, 1, 1,
+                                device=torch.device('cuda'))
+            file_names = data['key'][input_key]['filename']
+        else:
+            style_key = 'images_b' if a2b else 'images_a'
+            assert style_key in data.keys(), \
+                "{} must be provided when 'random_style' " \
+                "is set to False".format(style_key)
+            style_images = data[style_key]
+            style = style_encode(style_images)
+            file_names = \
+                [content_name + '_style_' + style_name
+                 for content_name, style_name in
+                    zip(data['key'][input_key]['filename'],
+                        data['key'][style_key]['filename'])]
+
+        output_images = decode(content, style)
+        return output_images, file_names
+
+
+class AutoEncoder(nn.Module):
+    r"""Improved MUNIT autoencoder.
+
+    Args:
+        num_filters (int): Base filter numbers.
+        max_num_filters (int): Maximum number of filters in the encoder.
+        num_filters_mlp (int): Base filter number in the MLP module.
+        latent_dim (int): Dimension of the style code.
+        num_res_blocks (int): Number of residual blocks at the end of the
+            content encoder.
+        num_mlp_blocks (int): Number of layers in the MLP module.
+        num_downsamples_style (int): Number of times we reduce
+            resolution by 2x2 for the style image.
+        num_downsamples_content (int): Number of times we reduce
+            resolution by 2x2 for the content image.
+        num_image_channels (int): Number of input image channels.
+        content_norm_type (str): Type of activation normalization in the
+            content encoder.
+        style_norm_type (str): Type of activation normalization in the
+            style encoder.
+        decoder_norm_type (str): Type of activation normalization in the
+            decoder.
+        weight_norm_type (str): Type of weight normalization.
+        decoder_norm_params (obj): Parameters of activation normalization in the
+            decoder. If not ``None``, decoder_norm_params.__dict__ will be used
+            as keyword arguments when initializing activation normalization.
+        output_nonlinearity (str): Type of nonlinearity before final output,
+            ``'tanh'`` or ``'none'``.
+        pre_act (bool): If ``True``, uses pre-activation residual blocks.
+        apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
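+
+    Example (a minimal sketch, assuming the default arguments and a
+    256 x 256 input)::
+
+        net = AutoEncoder()
+        content, style = net.encode(torch.randn(1, 3, 256, 256))
+        images_recon = net.decode(content, style)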
+    """
+
+    def __init__(self,
+                 num_filters=64,
+                 max_num_filters=256,
+                 num_filters_mlp=256,
+                 latent_dim=8,
+                 num_res_blocks=4,
+                 num_mlp_blocks=2,
+                 num_downsamples_style=4,
+                 num_downsamples_content=2,
+                 num_image_channels=3,
+                 content_norm_type='instance',
+                 style_norm_type='',
+                 decoder_norm_type='instance',
+                 weight_norm_type='',
+                 decoder_norm_params=SimpleNamespace(affine=False),
+                 output_nonlinearity='',
+                 pre_act=False,
+                 apply_noise=False,
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type':
+                warnings.warn(
+                    "Generator argument '{}' is not used.".format(key))
+        self.style_encoder = StyleEncoder(num_downsamples_style,
+                                          num_image_channels,
+                                          num_filters,
+                                          latent_dim,
+                                          'reflect',
+                                          style_norm_type,
+                                          weight_norm_type,
+                                          'relu')
+        self.content_encoder = ContentEncoder(num_downsamples_content,
+                                              num_res_blocks,
+                                              num_image_channels,
+                                              num_filters,
+                                              max_num_filters,
+                                              'reflect',
+                                              content_norm_type,
+                                              weight_norm_type,
+                                              'relu',
+                                              pre_act)
+        self.decoder = Decoder(num_downsamples_content,
+                               num_res_blocks,
+                               self.content_encoder.output_dim,
+                               num_image_channels,
+                               num_filters_mlp,
+                               'reflect',
+                               decoder_norm_type,
+                               decoder_norm_params,
+                               weight_norm_type,
+                               'relu',
+                               output_nonlinearity,
+                               pre_act,
+                               apply_noise)
+        self.mlp = MLP(latent_dim,
+                       num_filters_mlp,
+                       num_filters_mlp,
+                       num_mlp_blocks,
+                       'none',
+                       'relu')
+        self.style_channels = latent_dim
+
+    def forward(self, images):
+        r"""Reconstruct an image.
+
+        Args:
+            images (Tensor): Input images.
+        Returns:
+            images_recon (Tensor): Reconstructed images.
+        """
+        content, style = self.encode(images)
+        images_recon = self.decode(content, style)
+        return images_recon
+
+    def encode(self, images):
+        r"""Encode an image to content and style code.
+
+        Args:
+            images (Tensor): Input images.
+        Returns:
+            (tuple):
+              - content (Tensor): Content code.
+              - style (Tensor): Style code.
+        """
+        style = self.style_encoder(images)
+        content = self.content_encoder(images)
+        return content, style
+
+    def decode(self, content, style):
+        r"""Decode content and style code to an image.
+
+        Args:
+            content (Tensor): Content code.
+            style (Tensor): Style code.
+        Returns:
+            images (Tensor): Output images.
+        """
+        style = self.mlp(style)
+        images = self.decoder(content, style)
+        return images
+
+
+class StyleEncoder(nn.Module):
+    r"""MUNIT style encoder.
+
+    Args:
+        num_downsamples (int): Number of times we reduce
+            resolution by 2x2.
+        num_image_channels (int): Number of input image channels.
+        num_filters (int): Base filter numbers.
+        style_channels (int): Dimension of the style code.
+        padding_mode (string): Type of padding.
+        activation_norm_type (str): Type of activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+        nonlinearity (str): Type of nonlinear activation function.
+    """
+
+    def __init__(self, num_downsamples, num_image_channels, num_filters,
+                 style_channels, padding_mode, activation_norm_type,
+                 weight_norm_type, nonlinearity):
+        super().__init__()
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type=activation_norm_type,
+                           weight_norm_type=weight_norm_type,
+                           nonlinearity=nonlinearity,
+                           inplace_nonlinearity=True)
+        model = []
+        model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
+                              **conv_params)]
+        for i in range(2):
+            model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
+                                  **conv_params)]
+            num_filters *= 2
+        for i in range(num_downsamples - 2):
+            model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
+                                  **conv_params)]
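+        # Global average pooling followed by a 1x1 convolution produces the
+        # final `style_channels`-dimensional style code.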
+        model += [nn.AdaptiveAvgPool2d(1)]
+        model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
+        self.model = nn.Sequential(*model)
+        self.output_dim = num_filters
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input image.
+        """
+        return self.model(x)
+
+
+class Decoder(nn.Module):
+    r"""Improved MUNIT decoder. The network consists of
+
+    - $(num_res_blocks) residual blocks.
+    - $(num_upsamples) residual blocks or convolutional blocks
+    - output layer.
+
+    Args:
+        num_upsamples (int): Number of times we increase resolution by 2x2.
+        num_res_blocks (int): Number of residual blocks.
+        num_filters (int): Base filter numbers.
+        num_image_channels (int): Number of input image channels.
+        style_channels (int): Dimension of the style code.
+        padding_mode (string): Type of padding.
+        activation_norm_type (str): Type of activation normalization.
+        activation_norm_params (obj): Parameters of activation normalization.
+            If not ``None``, activation_norm_params.__dict__ will be used
+            as keyword arguments when initializing activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+        nonlinearity (str): Type of nonlinear activation function.
+        output_nonlinearity (str): Type of nonlinearity before final output,
+            ``'tanh'`` or ``'none'``.
+        pre_act (bool): If ``True``, uses pre-activation residual blocks.
+        apply_noise (bool): If ``True``, injects Gaussian noise.
+    """
+
+    def __init__(self,
+                 num_upsamples,
+                 num_res_blocks,
+                 num_filters,
+                 num_image_channels,
+                 style_channels,
+                 padding_mode,
+                 activation_norm_type,
+                 activation_norm_params,
+                 weight_norm_type,
+                 nonlinearity,
+                 output_nonlinearity,
+                 pre_act=False,
+                 apply_noise=False):
+        super().__init__()
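+        # The residual and upsampling blocks use AdaIN ('adaptive' activation
+        # norm), whose affine parameters are predicted from the style code
+        # (cond_dims=style_channels).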
+        adain_params = SimpleNamespace(
+            activation_norm_type=activation_norm_type,
+            activation_norm_params=activation_norm_params,
+            cond_dims=style_channels)
+        conv_params = dict(padding_mode=padding_mode,
+                           nonlinearity=nonlinearity,
+                           inplace_nonlinearity=True,
+                           apply_noise=apply_noise,
+                           weight_norm_type=weight_norm_type,
+                           activation_norm_type='adaptive',
+                           activation_norm_params=adain_params)
+
+        # The order of operations in residual blocks.
+        order = 'pre_act' if pre_act else 'CNACNA'
+
+        # Residual blocks with AdaIN.
+        self.decoder = nn.ModuleList()
+        for _ in range(num_res_blocks):
+            self.decoder += [Res2dBlock(num_filters, num_filters,
+                                        **conv_params,
+                                        order=order)]
+
+        # Convolutional blocks with upsampling.
+        for i in range(num_upsamples):
+            self.decoder += [NearestUpsample(scale_factor=2)]
+            self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
+                                         5, 1, 2, **conv_params)]
+            num_filters //= 2
+        self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
+                                     nonlinearity=output_nonlinearity,
+                                     padding_mode=padding_mode)]
+
+    def forward(self, x, style):
+        r"""
+
+        Args:
+            x (tensor): Content embedding of the content image.
+            style (tensor): Style embedding of the style image.
+        """
+        for block in self.decoder:
+            if getattr(block, 'conditional', False):
+                x = block(x, style)
+            else:
+                x = block(x)
+        return x
+
+
+class MLP(nn.Module):
+    r"""The multi-layer perceptron (MLP) that maps Gaussian style code to a
+    feature vector that is given as the conditional input to AdaIN.
+
+    Args:
+        input_dim (int): Number of channels in the input tensor.
+        output_dim (int): Number of channels in the output tensor.
+        latent_dim (int): Number of channels in the latent features.
+        num_layers (int): Number of layers in the MLP.
+        norm (str): Type of activation normalization.
+        nonlinearity (str): Type of nonlinear activation function.
+    """
+
+    def __init__(self, input_dim, output_dim, latent_dim, num_layers,
+                 norm, nonlinearity):
+        super().__init__()
+        model = []
+        model += [LinearBlock(input_dim, latent_dim,
+                              activation_norm_type=norm,
+                              nonlinearity=nonlinearity)]
+        for i in range(num_layers - 2):
+            model += [LinearBlock(latent_dim, latent_dim,
+                                  activation_norm_type=norm,
+                                  nonlinearity=nonlinearity)]
+        model += [LinearBlock(latent_dim, output_dim,
+                              activation_norm_type=norm,
+                              nonlinearity=nonlinearity)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input style code.
+        """
+        return self.model(x.view(x.size(0), -1))
diff --git a/imaginaire/generators/pix2pixHD.py b/imaginaire/generators/pix2pixHD.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd2e36b31b2b045594d3dd1d7db7cb4ee336d6f8
--- /dev/null
+++ b/imaginaire/generators/pix2pixHD.py
@@ -0,0 +1,348 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Generator(nn.Module):
+    r"""Pix2pixHD coarse-to-fine generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        # pix2pixHD has a global generator.
+        global_gen_cfg = gen_cfg.global_generator
+        num_filters_global = getattr(global_gen_cfg, 'num_filters', 64)
+        # Optionally, it can have several local enhancers. They are useful
+        # for generating high resolution images.
+        local_gen_cfg = gen_cfg.local_enhancer
+        self.num_local_enhancers = num_local_enhancers = \
+            getattr(local_gen_cfg, 'num_enhancers', 1)
+        # By default, pix2pixHD using instance normalization.
+        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+                                       'instance')
+        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+                                         None)
+        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+        base_conv_block = partial(Conv2dBlock,
+                                  padding_mode=padding_mode,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  activation_norm_params=activation_norm_params,
+                                  nonlinearity='relu')
+        base_res_block = partial(Res2dBlock,
+                                 padding_mode=padding_mode,
+                                 weight_norm_type=weight_norm_type,
+                                 activation_norm_type=activation_norm_type,
+                                 activation_norm_params=activation_norm_params,
+                                 nonlinearity='relu', order='CNACN')
+        # Know what is the number of available segmentation labels.
+        num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        self.concat_features = False
+        # Check whether label input contains specific type of data (e.g.
+        # instance_maps).
+        self.contain_instance_map = False
+        if data_cfg.input_labels[-1] == 'instance_maps':
+            self.contain_instance_map = True
+        # The feature encoder is only useful when the instance map is provided.
+        if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
+            num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
+            if num_feat_channels > 0:
+                num_input_channels += num_feat_channels
+                self.concat_features = True
+                self.encoder = Encoder(gen_cfg.enc, data_cfg)
+
+        # Global generator model.
+        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
+                                       num_input_channels, padding_mode,
+                                       base_conv_block, base_res_block)
+        if num_local_enhancers == 0:
+            self.global_model = global_model
+        else:
+            # Get rid of the last layer.
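+            # (The final output layer is dropped so the global generator
+            # produces feature maps for the first local enhancer rather than
+            # an RGB image.)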
+            global_model = global_model.model
+            global_model = [global_model[i]
+                            for i in range(len(global_model) - 1)]
+            # global_model = [global_model[i]
+            #                 for i in range(len(global_model) - 2)]
+            self.global_model = nn.Sequential(*global_model)
+
+        # Local enhancer model.
+        for n in range(num_local_enhancers):
+            # num_filters = num_filters_global // (2 ** n)
+            num_filters = num_filters_global // (2 ** (n + 1))
+            output_img = (n == num_local_enhancers - 1)
+            setattr(self, 'enhancer_%d' % n,
+                    LocalEnhancer(local_gen_cfg, data_cfg,
+                                  num_input_channels, num_filters,
+                                  padding_mode, base_conv_block,
+                                  base_res_block, output_img))
+
+        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
+                                       count_include_pad=False)
+
+    def forward(self, data, random_style=False):
+        r"""Coarse-to-fine generator forward.
+
+        Args:
+            data (dict) : Dictionary of input data.
+            random_style (bool): Always set to false for the pix2pixHD model.
+        Returns:
+            output (dict) : Dictionary of output data.
+        """
+        label = data['label']
+
+        output = dict()
+        if self.concat_features:
+            features = self.encoder(data['images'], data['instance_maps'])
+            label = torch.cat([label, features], dim=1)
+            output['feature_maps'] = features
+
+        # Create input pyramid.
+        input_downsampled = [label]
+        for i in range(self.num_local_enhancers):
+            input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+        # Output at coarsest level.
+        x = self.global_model(input_downsampled[-1])
+
+        # Coarse-to-fine: build up one layer at a time.
+        for n in range(self.num_local_enhancers):
+            input_n = input_downsampled[self.num_local_enhancers - n - 1]
+            enhancer = getattr(self, 'enhancer_%d' % n)
+            x = enhancer(x, input_n)
+
+        output['fake_images'] = x
+        return output
+
+    def load_pretrained_network(self, pretrained_dict):
+        r"""Load a pretrained network."""
+        # print(pretrained_dict.keys())
+        model_dict = self.state_dict()
+        print('Pretrained network has fewer layers; the following are '
+              'not initialized:')
+
+        not_initialized = set()
+        for k, v in model_dict.items():
+            kp = 'module.' + k.replace('global_model.', 'global_model.model.')
+            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+                model_dict[k] = pretrained_dict[kp]
+            else:
+                not_initialized.add('.'.join(k.split('.')[:2]))
+        print(sorted(not_initialized))
+        self.load_state_dict(model_dict)
+
+    def inference(self, data, **kwargs):
+        r"""Generator inference.
+
+        Args:
+            data (dict) : Dictionary of input data.
+        Returns:
+            fake_images (tensor): Output fake images.
+            file_names (str): Data file name.
+        """
+        output = self.forward(data, **kwargs)
+        return output['fake_images'], data['key']['seg_maps'][0]
+
+
+class LocalEnhancer(nn.Module):
+    r"""Local enhancer constructor. These are sub-networks that are useful
+    when aiming to produce high-resolution outputs.
+
+    Args:
+        gen_cfg (obj): Local generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+        num_input_channels (int): Number of segmentation labels.
+        num_filters (int): Number of filters for the first layer.
+        padding_mode (str): zero | reflect | ...
+        base_conv_block (obj): Conv block with preset attributes.
+        base_res_block (obj): Residual block with preset attributes.
+        output_img (bool): If ``True``, output an RGB image; otherwise output a feature map.
+    """
+
+    def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
+                 padding_mode, base_conv_block, base_res_block,
+                 output_img=False):
+        super(LocalEnhancer, self).__init__()
+        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3)
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        # Downsample.
+        model_downsample = \
+            [base_conv_block(num_input_channels, num_filters, 7, padding=3),
+             base_conv_block(num_filters, num_filters * 2, 3, stride=2,
+                             padding=1)]
+        # Residual blocks.
+        model_upsample = []
+        for i in range(num_res_blocks):
+            model_upsample += [base_res_block(num_filters * 2, num_filters * 2,
+                                              3, padding=1)]
+        # Upsample.
+        model_upsample += \
+            [NearestUpsample(scale_factor=2),
+             base_conv_block(num_filters * 2, num_filters, 3, padding=1)]
+
+        # Final convolution.
+        if output_img:
+            model_upsample += [Conv2dBlock(num_filters, num_img_channels, 7,
+                                           padding=3, padding_mode=padding_mode,
+                                           nonlinearity='tanh')]
+
+        self.model_downsample = nn.Sequential(*model_downsample)
+        self.model_upsample = nn.Sequential(*model_upsample)
+
+    def forward(self, output_coarse, input_fine):
+        r"""Local enhancer forward.
+
+        Args:
+            output_coarse (4D tensor) : Coarse output from previous layer.
+            input_fine (4D tensor) : Fine input from current layer.
+        Returns:
+            output (4D tensor) : Refined output.
+        """
+        output = self.model_upsample(self.model_downsample(input_fine) + output_coarse)
+        return output
+
+
+class GlobalGenerator(nn.Module):
+    r"""Coarse generator constructor. This is the main generator in the
+    pix2pixHD architecture.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+        num_input_channels (int): Number of segmentation labels.
+        padding_mode (str): zero | reflect | ...
+        base_conv_block (obj): Conv block with preset attributes.
+        base_res_block (obj): Residual block with preset attributes.
+    """
+
+    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
+                 base_conv_block, base_res_block):
+        super(GlobalGenerator, self).__init__()
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        num_filters = getattr(gen_cfg, 'num_filters', 64)
+        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
+        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
+        # First layer.
+        model = [base_conv_block(num_input_channels, num_filters,
+                                 kernel_size=7, padding=3)]
+        # Downsample.
+        for i in range(num_downsamples):
+            ch = num_filters * (2 ** i)
+            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+        # ResNet blocks.
+        ch = num_filters * (2 ** num_downsamples)
+        for i in range(num_res_blocks):
+            model += [base_res_block(ch, ch, 3, padding=1)]
+        # Upsample.
+        num_upsamples = num_downsamples
+        for i in reversed(range(num_upsamples)):
+            ch = num_filters * (2 ** i)
+            model += \
+                [NearestUpsample(scale_factor=2),
+                 base_conv_block(ch * 2, ch, 3, padding=1)]
+        model += [Conv2dBlock(num_filters, num_img_channels, 7, padding=3,
+                              padding_mode=padding_mode, nonlinearity='tanh')]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        r"""Coarse-to-fine generator forward.
+
+        Args:
+            input (4D tensor) : Input semantic representations.
+        Returns:
+            output (4D tensor) : Synthesized image by generator.
+        """
+        return self.model(input)
+
+
+class Encoder(nn.Module):
+    r"""Encoder for getting region-wise features for style control.
+
+    Args:
+        enc_cfg (obj): Encoder definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, enc_cfg, data_cfg):
+        super(Encoder, self).__init__()
+        label_nc = get_paired_input_label_channel_number(data_cfg)
+        feat_nc = enc_cfg.num_feat_channels
+        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
+        for i in range(label_nc):
+            dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32)
+            self.register_buffer('cluster_%d' % i,
+                                 torch.tensor(dummy_arr, dtype=torch.float32))
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
+        num_filters = getattr(enc_cfg, 'num_filters', 64)
+        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
+        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
+        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
+                                       'instance')
+        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
+        base_conv_block = partial(Conv2dBlock,
+                                  padding_mode=padding_mode,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  nonlinearity='relu')
+        model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
+        # Downsample.
+        for i in range(num_downsamples):
+            ch = num_filters * (2**i)
+            model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
+        # Upsample.
+        for i in reversed(range(num_downsamples)):
+            ch = num_filters * (2 ** i)
+            model += [NearestUpsample(scale_factor=2),
+                      base_conv_block(ch * 2, ch, 3, padding=1)]
+
+        model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
+                              padding=3, padding_mode=padding_mode,
+                              nonlinearity='tanh')]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input, instance_map):
+        r"""Extracting region-wise features
+
+        Args:
+            input (4D tensor): Real RGB images.
+            instance_map (4D tensor): Instance label mask.
+        Returns:
+            outputs_mean (4D tensor): Instance-wise average-pooled
+                feature maps.
+        """
+        outputs = self.model(input)
+        # Instance-wise average pooling.
+        outputs_mean = torch.zeros_like(outputs)
+        # Find all the unique labels in this batch.
+        inst_list = np.unique(instance_map.cpu().numpy().astype(int))
+        for i in inst_list:
+            for b in range(input.size(0)):
+                # Find the pixels in this instance map that have this instance label.
+                indices = (instance_map[b:b+1] == int(i)).nonzero()  # n x 4
+                # Scan through the feature channels.
+                for j in range(self.num_feat_channels):
+                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
+                                         indices[:, 2], indices[:, 3]]
+                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
+                    outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
+                                 indices[:, 2], indices[:, 3]] = mean_feat
+        return outputs_mean
diff --git a/imaginaire/generators/spade.py b/imaginaire/generators/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc69630304ccb2ce3fab707ca2e7de5f7aeec55a
--- /dev/null
+++ b/imaginaire/generators/spade.py
@@ -0,0 +1,571 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import math
+import types
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.utils.data import (get_crop_h_w,
+                                   get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Generator(nn.Module):
+    r"""SPADE generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super(Generator, self).__init__()
+        print('SPADE generator initialization.')
+        # We assume the first datum is the ground truth image.
+        image_channels = getattr(gen_cfg, 'image_channels', None)
+        if image_channels is None:
+            image_channels = get_paired_input_image_channel_number(data_cfg)
+        num_labels = getattr(gen_cfg, 'num_labels', None)
+        if num_labels is None:
+            # Calculate number of channels in the input label when not specified.
+            num_labels = get_paired_input_label_channel_number(data_cfg)
+        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
+        # Build the generator
+        out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
+        num_filters = getattr(gen_cfg, 'num_filters', 128)
+        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
+        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
+
+        cond_dims = 0
+        # Check whether we use the style code.
+        style_dims = getattr(gen_cfg, 'style_dims', None)
+        self.style_dims = style_dims
+        if style_dims is not None:
+            print('\tStyle code dimensions: %d' % style_dims)
+            cond_dims += style_dims
+            self.use_style = True
+        else:
+            self.use_style = False
+        # Check whether we use the attribute code.
+        if hasattr(gen_cfg, 'attribute_dims'):
+            self.use_attribute = True
+            self.attribute_dims = gen_cfg.attribute_dims
+            cond_dims += gen_cfg.attribute_dims
+        else:
+            self.use_attribute = False
+
+        if not self.use_style and not self.use_attribute:
+            self.use_style_encoder = False
+        else:
+            self.use_style_encoder = True
+        print('\tBase filter number: %d' % num_filters)
+        print('\tConvolution kernel size: %d' % kernel_size)
+        print('\tWeight norm type: %s' % weight_norm_type)
+        skip_activation_norm = \
+            getattr(gen_cfg, 'skip_activation_norm', True)
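+        # Fill in defaults for the spatially-adaptive normalization layers when
+        # they are not specified in the config.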
+        activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None)
+        if activation_norm_params is None:
+            activation_norm_params = types.SimpleNamespace()
+        if not hasattr(activation_norm_params, 'num_filters'):
+            setattr(activation_norm_params, 'num_filters', 128)
+        if not hasattr(activation_norm_params, 'kernel_size'):
+            setattr(activation_norm_params, 'kernel_size', 3)
+        if not hasattr(activation_norm_params, 'activation_norm_type'):
+            setattr(activation_norm_params, 'activation_norm_type', 'sync_batch')
+        if not hasattr(activation_norm_params, 'separate_projection'):
+            setattr(activation_norm_params, 'separate_projection', False)
+        if not hasattr(activation_norm_params, 'activation_norm_params'):
+            activation_norm_params.activation_norm_params = types.SimpleNamespace()
+            activation_norm_params.activation_norm_params.affine = True
+        setattr(activation_norm_params, 'cond_dims', num_labels)
+        if not hasattr(activation_norm_params, 'weight_norm_type'):
+            setattr(activation_norm_params, 'weight_norm_type', weight_norm_type)
+        global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch')
+        use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True)
+        output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
+        print(activation_norm_params)
+        self.spade_generator = SPADEGenerator(num_labels,
+                                              out_image_small_side_size,
+                                              image_channels,
+                                              num_filters,
+                                              kernel_size,
+                                              cond_dims,
+                                              activation_norm_params,
+                                              weight_norm_type,
+                                              global_adaptive_norm_type,
+                                              skip_activation_norm,
+                                              use_posenc_in_input_layer,
+                                              self.use_style_encoder,
+                                              output_multiplier)
+        if self.use_style:
+            # Build the encoder.
+            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
+            if style_enc_cfg is None:
+                style_enc_cfg = types.SimpleNamespace()
+            if not hasattr(style_enc_cfg, 'num_filters'):
+                setattr(style_enc_cfg, 'num_filters', 128)
+            if not hasattr(style_enc_cfg, 'kernel_size'):
+                setattr(style_enc_cfg, 'kernel_size', 3)
+            if not hasattr(style_enc_cfg, 'weight_norm_type'):
+                setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
+            setattr(style_enc_cfg, 'input_image_channels', image_channels)
+            setattr(style_enc_cfg, 'style_dims', style_dims)
+            self.style_encoder = StyleEncoder(style_enc_cfg)
+
+        self.z = None
+        print('Done with the SPADE generator initialization.')
+
+    def forward(self, data, random_style=False):
+        r"""SPADE Generator forward.
+
+        Args:
+            data (dict):
+              - images (N x C1 x H x W tensor) : Ground truth images
+              - label (N x C2 x H x W tensor) : Semantic representations
+              - z (N x style_dims tensor): Gaussian random noise
+              - random_style (bool): Whether to sample a random style vector.
+        Returns:
+            (dict):
+              - fake_images (N x 3 x H x W tensor): fake images
+              - mu (N x C1 tensor): mean vectors
+              - logvar (N x C1 tensor): log-variance vectors
+        """
+        if self.use_style_encoder:
+            if random_style:
+                bs = data['label'].size(0)
+                z = torch.randn(
+                    bs, self.style_dims, dtype=torch.float32).cuda()
+                if data['label'].dtype == torch.float16:
+                    z = z.half()
+                mu = None
+                logvar = None
+            else:
+                mu, logvar, z = self.style_encoder(data['images'])
+            if self.use_attribute:
+                data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1)
+            else:
+                data['z'] = z
+        output = self.spade_generator(data)
+        if self.use_style_encoder:
+            output['mu'] = mu
+            output['logvar'] = logvar
+        return output
+
+    def inference(self,
+                  data,
+                  random_style=False,
+                  use_fixed_random_style=False,
+                  keep_original_size=False):
+        r"""Compute results images for a batch of input data and save the
+        results in the specified folder.
+
+        Args:
+            data (dict):
+              - images (N x C1 x H x W tensor) : Ground truth images
+              - label (N x C2 x H x W tensor) : Semantic representations
+              - z (N x style_dims tensor): Gaussian random noise
+            random_style (bool): Whether to sample a random style vector.
+            use_fixed_random_style (bool): Sample random style once and use it
+                for all the remaining inference.
+            keep_original_size (bool): Keep original size of the input.
+        Returns:
+            (dict):
+              - fake_images (N x 3 x H x W tensor): fake images
+              - mu (N x C1 tensor): mean vectors
+              - logvar (N x C1 tensor): log-variance vectors
+        """
+        self.eval()
+        self.spade_generator.eval()
+
+        if self.use_style_encoder:
+            if random_style and self.use_style_encoder:
+                if self.z is None or not use_fixed_random_style:
+                    bs = data['label'].size(0)
+                    z = torch.randn(
+                        bs, self.style_dims, dtype=torch.float32).to('cuda')
+                    if data['label'].dtype == torch.float16:
+                        z = z.half()
+                    self.z = z
+                else:
+                    z = self.z
+            else:
+                mu, logvar, z = self.style_encoder(data['images'])
+            data['z'] = z
+
+        output = self.spade_generator(data)
+        output_images = output['fake_images']
+
+        if keep_original_size:
+            height = data['original_h_w'][0][0]
+            width = data['original_h_w'][0][1]
+            output_images = torch.nn.functional.interpolate(
+                output_images, size=[height, width])
+
+        for key in data['key'].keys():
+            if 'segmaps' in key or 'seg_maps' in key:
+                file_names = data['key'][key][0]
+                break
+        for key in data['key'].keys():
+            if 'edgemaps' in key or 'edge_maps' in key:
+                file_names = data['key'][key][0]
+                break
+
+        return output_images, file_names
+
+
+class SPADEGenerator(nn.Module):
+    r"""SPADE Image Generator constructor.
+
+    Args:
+        num_labels (int): Number of different labels.
+        out_image_small_side_size (int): min(width, height)
+        image_channels (int): Num. of channels of the output image.
+        num_filters (int): Base filter numbers.
+        kernel_size (int): Convolution kernel size.
+        style_dims (int): Dimensions of the style code.
+        activation_norm_params (obj): Spatially adaptive normalization param.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        global_adaptive_norm_type (str): Type of normalization in SPADE.
+        skip_activation_norm (bool): If ``True``, applies activation norm to the
+            shortcut connection in residual blocks.
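+        use_posenc_in_input_layer (bool): If ``True``, two positional-encoding
+            channels are concatenated to the label map fed to the first layer.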
+        use_style_encoder (bool): Whether to use global adaptive norm
+            like conditional batch norm or adaptive instance norm.
+        output_multiplier (float): A positive number multiplied to the output
+    """
+
+    def __init__(self,
+                 num_labels,
+                 out_image_small_side_size,
+                 image_channels,
+                 num_filters,
+                 kernel_size,
+                 style_dims,
+                 activation_norm_params,
+                 weight_norm_type,
+                 global_adaptive_norm_type,
+                 skip_activation_norm,
+                 use_posenc_in_input_layer,
+                 use_style_encoder,
+                 output_multiplier):
+        super(SPADEGenerator, self).__init__()
+        self.output_multiplier = output_multiplier
+        self.use_style_encoder = use_style_encoder
+        self.use_posenc_in_input_layer = use_posenc_in_input_layer
+        self.out_image_small_side_size = out_image_small_side_size
+        self.num_filters = num_filters
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        nonlinearity = 'leakyrelu'
+        activation_norm_type = 'spatially_adaptive'
+        base_res2d_block = \
+            functools.partial(Res2dBlock,
+                              kernel_size=kernel_size,
+                              padding=padding,
+                              bias=[True, True, False],
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              activation_norm_params=activation_norm_params,
+                              skip_activation_norm=skip_activation_norm,
+                              nonlinearity=nonlinearity,
+                              order='NACNAC')
+        if self.use_style_encoder:
+            self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
+                                    weight_norm_type=weight_norm_type,
+                                    nonlinearity='relu',
+                                    order='CAN')
+            self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
+                                    weight_norm_type=weight_norm_type,
+                                    nonlinearity='relu',
+                                    order='CAN')
+
+            adaptive_norm_params = types.SimpleNamespace(
+                cond_dims=2 * style_dims,
+                activation_norm_type=global_adaptive_norm_type,
+                weight_norm_type=activation_norm_params.weight_norm_type,
+                separate_projection=activation_norm_params.separate_projection)
+            adaptive_norm_params.activation_norm_params = types.SimpleNamespace(
+                affine=activation_norm_params.activation_norm_params.affine)
+            base_cbn2d_block = \
+                functools.partial(Conv2dBlock,
+                                  kernel_size=kernel_size,
+                                  stride=1,
+                                  padding=padding,
+                                  bias=True,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type='adaptive',
+                                  activation_norm_params=adaptive_norm_params,
+                                  nonlinearity=nonlinearity,
+                                  order='NAC')
+        else:
+            base_conv2d_block = \
+                functools.partial(Conv2dBlock,
+                                  kernel_size=kernel_size,
+                                  stride=1,
+                                  padding=padding,
+                                  bias=True,
+                                  weight_norm_type=weight_norm_type,
+                                  nonlinearity=nonlinearity,
+                                  order='NAC')
+        in_num_labels = num_labels
+        in_num_labels += 2 if self.use_posenc_in_input_layer else 0
+        self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
+                                  kernel_size=kernel_size, stride=1,
+                                  padding=padding,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type='none',
+                                  nonlinearity=nonlinearity)
+        if self.use_style_encoder:
+            self.cbn_head_0 = base_cbn2d_block(
+                8 * num_filters, 16 * num_filters)
+        else:
+            self.conv_head_0 = base_conv2d_block(
+                8 * num_filters, 16 * num_filters)
+        self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
+        self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
+
+        self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
+        if self.use_style_encoder:
+            self.cbn_up_0a = base_cbn2d_block(
+                8 * num_filters, 8 * num_filters)
+        else:
+            self.conv_up_0a = base_conv2d_block(
+                8 * num_filters, 8 * num_filters)
+        self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
+
+        self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
+        if self.use_style_encoder:
+            self.cbn_up_1a = base_cbn2d_block(
+                4 * num_filters, 4 * num_filters)
+        else:
+            self.conv_up_1a = base_conv2d_block(
+                4 * num_filters, 4 * num_filters)
+        self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
+        self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
+        if self.use_style_encoder:
+            self.cbn_up_2a = base_cbn2d_block(
+                4 * num_filters, 4 * num_filters)
+        else:
+            self.conv_up_2a = base_conv2d_block(
+                4 * num_filters, 4 * num_filters)
+        self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
+        self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
+                                       5, stride=1, padding=2,
+                                       weight_norm_type=weight_norm_type,
+                                       activation_norm_type='none',
+                                       nonlinearity=nonlinearity,
+                                       order='ANC')
+        self.base = 16
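+        # self.base is the total downsampling factor between the output image
+        # and the coarsest feature map; it is raised to 32 / 64 below for
+        # 512 / 1024 outputs.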
+        if self.out_image_small_side_size == 512:
+            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
+            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
+            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
+                                           5, stride=1, padding=2,
+                                           weight_norm_type=weight_norm_type,
+                                           activation_norm_type='none',
+                                           nonlinearity=nonlinearity,
+                                           order='ANC')
+            self.base = 32
+        if self.out_image_small_side_size == 1024:
+            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
+            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
+            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
+                                           5, stride=1, padding=2,
+                                           weight_norm_type=weight_norm_type,
+                                           activation_norm_type='none',
+                                           nonlinearity=nonlinearity,
+                                           order='ANC')
+            self.up_4a = base_res2d_block(num_filters, num_filters // 2)
+            self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
+            self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
+                                            5, stride=1, padding=2,
+                                            weight_norm_type=weight_norm_type,
+                                            activation_norm_type='none',
+                                            nonlinearity=nonlinearity,
+                                            order='ANC')
+            self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest')
+            self.base = 64
+        if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \
+                and self.out_image_small_side_size != 1024:
+            raise ValueError('Generation image size (%d, %d) not supported' %
+                             (self.out_image_small_side_size,
+                              self.out_image_small_side_size))
+        self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest')
+
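+        # Fixed 16x16 grid of (x, y) coordinates in [-1, 1], used as a
+        # 2-channel positional encoding when use_posenc_in_input_layer is True.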
+        xv, yv = torch.meshgrid(
+            [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
+        self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
+        self.xy = self.xy.cuda()
+
+    def forward(self, data):
+        r"""SPADE Generator forward.
+
+        Args:
+            data (dict):
+              - data  (N x C1 x H x W tensor) : Ground truth images.
+              - label (N x C2 x H x W tensor) : Semantic representations.
+              - z (N x style_dims tensor): Gaussian random noise.
+        Returns:
+            output (dict):
+              - fake_images (N x 3 x H x W tensor): Fake images.
+        """
+        seg = data['label']
+
+        if self.use_style_encoder:
+            z = data['z']
+            z = self.fc_0(z)
+            z = self.fc_1(z)
+
+        # Downsample the label map so that its smaller side becomes 16
+        # (i.e. spatial size H/base x W/base).
+        sy = math.floor(seg.size()[2] * 1.0 / self.base)
+        sx = math.floor(seg.size()[3] * 1.0 / self.base)
+
+        in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
+        if self.use_posenc_in_input_layer:
+            in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
+            in_seg_xy = torch.cat(
+                (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
+        else:
+            in_seg_xy = in_seg
+        # 16x16
+        x = self.head_0(in_seg_xy)
+        if self.use_style_encoder:
+            x = self.cbn_head_0(x, z)
+        else:
+            x = self.conv_head_0(x)
+        x = self.head_1(x, seg)
+        x = self.head_2(x, seg)
+        x = self.nearest_upsample2x(x)
+        # 32x32
+        x = self.up_0a(x, seg)
+        if self.use_style_encoder:
+            x = self.cbn_up_0a(x, z)
+        else:
+            x = self.conv_up_0a(x)
+        x = self.up_0b(x, seg)
+        x = self.nearest_upsample2x(x)
+        # 64x64
+        x = self.up_1a(x, seg)
+        if self.use_style_encoder:
+            x = self.cbn_up_1a(x, z)
+        else:
+            x = self.conv_up_1a(x)
+        x = self.up_1b(x, seg)
+        x = self.nearest_upsample2x(x)
+        # 128x128
+        x = self.up_2a(x, seg)
+        if self.use_style_encoder:
+            x = self.cbn_up_2a(x, z)
+        else:
+            x = self.conv_up_2a(x)
+        x = self.up_2b(x, seg)
+        x = self.nearest_upsample2x(x)
+        # 256x256
+        if self.out_image_small_side_size == 256:
+            x256 = self.conv_img256(x)
+            x = torch.tanh(self.output_multiplier * x256)
+        # 512x512
+        elif self.out_image_small_side_size == 512:
+            x256 = self.conv_img256(x)
+            x256 = self.nearest_upsample2x(x256)
+            x = self.up_3a(x, seg)
+            x = self.up_3b(x, seg)
+            x = self.nearest_upsample2x(x)
+            x512 = self.conv_img512(x)
+            x = torch.tanh(self.output_multiplier * (x256 + x512))
+        # 1024x1024
+        elif self.out_image_small_side_size == 1024:
+            x256 = self.conv_img256(x)
+            x256 = self.nearest_upsample4x(x256)
+            x = self.up_3a(x, seg)
+            x = self.up_3b(x, seg)
+            x = self.nearest_upsample2x(x)
+            x512 = self.conv_img512(x)
+            x512 = self.nearest_upsample2x(x512)
+            x = self.up_4a(x, seg)
+            x = self.up_4b(x, seg)
+            x = self.nearest_upsample2x(x)
+            x1024 = self.conv_img1024(x)
+            x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
+        output = dict()
+        output['fake_images'] = x
+        return output
+
+
+class StyleEncoder(nn.Module):
+    r"""Style Encode constructor.
+
+    Args:
+        style_enc_cfg (obj): Style encoder definition file.
+    """
+
+    def __init__(self, style_enc_cfg):
+        super(StyleEncoder, self).__init__()
+        input_image_channels = style_enc_cfg.input_image_channels
+        num_filters = style_enc_cfg.num_filters
+        kernel_size = style_enc_cfg.kernel_size
+        padding = int(np.ceil((kernel_size - 1.0) / 2))
+        style_dims = style_enc_cfg.style_dims
+        weight_norm_type = style_enc_cfg.weight_norm_type
+        activation_norm_type = 'none'
+        nonlinearity = 'leakyrelu'
+        base_conv2d_block = \
+            functools.partial(Conv2dBlock,
+                              kernel_size=kernel_size,
+                              stride=2,
+                              padding=padding,
+                              weight_norm_type=weight_norm_type,
+                              activation_norm_type=activation_norm_type,
+                              # inplace_nonlinearity=True,
+                              nonlinearity=nonlinearity)
+        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
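+        # Six stride-2 convolutions reduce the 256x256 input to 4x4, hence the
+        # num_filters * 8 * 4 * 4 flattened size below.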
+        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+        self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+
+    def forward(self, input_x):
+        r"""SPADE Style Encoder forward.
+
+        Args:
+            input_x (N x 3 x H x W tensor): input images.
+        Returns:
+            (tuple):
+              - mu (N x C tensor): Mean vectors.
+              - logvar (N x C tensor): Log-variance vectors.
+              - z (N x C tensor): Style code vectors.
+        """
+        if input_x.size(2) != 256 or input_x.size(3) != 256:
+            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+        x = self.layer1(input_x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = x.view(x.size(0), -1)
+        mu = self.fc_mu(x)
+        logvar = self.fc_var(x)
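+        # Reparameterization trick: z = mu + eps * std keeps sampling
+        # differentiable w.r.t. mu and logvar.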
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        z = eps.mul(std) + mu
+        return mu, logvar, z
diff --git a/imaginaire/generators/unit.py b/imaginaire/generators/unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09f1b050d59de4940d49c7336ddd2f5928c55c6
--- /dev/null
+++ b/imaginaire/generators/unit.py
@@ -0,0 +1,312 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+
+
+class Generator(nn.Module):
+    r"""Improved UNIT generator.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
+        self.autoencoder_b = AutoEncoder(**vars(gen_cfg))
+
+    def forward(self, data, image_recon=True, cycle_recon=True):
+        r"""UNIT forward function"""
+        images_a = data['images_a']
+        images_b = data['images_b']
+        net_G_output = dict()
+
+        # encode input images into latent code
+        content_a = self.autoencoder_a.content_encoder(images_a)
+        content_b = self.autoencoder_b.content_encoder(images_b)
+
+        # decode (within domain)
+        if image_recon:
+            images_aa = self.autoencoder_a.decoder(content_a)
+            images_bb = self.autoencoder_b.decoder(content_b)
+            net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))
+
+        # decode (cross domain)
+        images_ba = self.autoencoder_a.decoder(content_b)
+        images_ab = self.autoencoder_b.decoder(content_a)
+
+        # cycle reconstruction
+        if cycle_recon:
+            content_ba = self.autoencoder_a.content_encoder(images_ba)
+            content_ab = self.autoencoder_b.content_encoder(images_ab)
+            images_aba = self.autoencoder_a.decoder(content_ab)
+            images_bab = self.autoencoder_b.decoder(content_ba)
+            net_G_output.update(
+                dict(content_ba=content_ba, content_ab=content_ab,
+                     images_aba=images_aba, images_bab=images_bab))
+
+        # required outputs
+        net_G_output.update(dict(content_a=content_a, content_b=content_b,
+                                 images_ba=images_ba, images_ab=images_ab))
+
+        return net_G_output
+
+    def inference(self, data, a2b=True):
+        r"""UNIT inference.
+
+        Args:
+            data (dict): Training data at the current iteration.
+              - images_a (tensor): Images from domain A.
+              - images_b (tensor): Images from domain B.
+            a2b (bool): If ``True``, translates images from domain A to B,
+                otherwise from B to A.
+        """
+        if a2b:
+            input_key = 'images_a'
+            content_encode = self.autoencoder_a.content_encoder
+            decode = self.autoencoder_b.decoder
+        else:
+            input_key = 'images_b'
+            content_encode = self.autoencoder_b.content_encoder
+            decode = self.autoencoder_a.decoder
+
+        content_images = data[input_key]
+        content = content_encode(content_images)
+        output_images = decode(content)
+        filename = '%s/%s' % (
+            data['key'][input_key]['sequence_name'][0],
+            data['key'][input_key]['filename'][0])
+        filenames = [filename]
+        return output_images, filenames
+
+
+class AutoEncoder(nn.Module):
+    r"""Improved UNIT autoencoder.
+
+    Args:
+        num_filters (int): Base filter numbers.
+        max_num_filters (int): Maximum number of filters in the encoder.
+        num_res_blocks (int): Number of residual blocks at the end of the
+            content encoder.
+        num_downsamples_content (int): Number of times we reduce
+            resolution by 2x2 for the content image.
+        num_image_channels (int): Number of input image channels.
+        content_norm_type (str): Type of activation normalization in the
+            content encoder.
+        decoder_norm_type (str): Type of activation normalization in the
+            decoder.
+        weight_norm_type (str): Type of weight normalization.
+        output_nonlinearity (str): Type of nonlinearity before final output,
+            ``'tanh'`` or ``'none'``.
+        pre_act (bool): If ``True``, uses pre-activation residual blocks.
+        apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
+    """
+
+    def __init__(self,
+                 num_filters=64,
+                 max_num_filters=256,
+                 num_res_blocks=4,
+                 num_downsamples_content=2,
+                 num_image_channels=3,
+                 content_norm_type='instance',
+                 decoder_norm_type='instance',
+                 weight_norm_type='',
+                 output_nonlinearity='',
+                 pre_act=False,
+                 apply_noise=False,
+                 **kwargs):
+        super().__init__()
+        for key in kwargs:
+            if key != 'type':
+                warnings.warn(
+                    "Generator argument '{}' is not used.".format(key))
+        self.content_encoder = ContentEncoder(num_downsamples_content,
+                                              num_res_blocks,
+                                              num_image_channels,
+                                              num_filters,
+                                              max_num_filters,
+                                              'reflect',
+                                              content_norm_type,
+                                              weight_norm_type,
+                                              'relu',
+                                              pre_act)
+        self.decoder = Decoder(num_downsamples_content,
+                               num_res_blocks,
+                               self.content_encoder.output_dim,
+                               num_image_channels,
+                               'reflect',
+                               decoder_norm_type,
+                               weight_norm_type,
+                               'relu',
+                               output_nonlinearity,
+                               pre_act,
+                               apply_noise)
+
+    def forward(self, images):
+        r"""Reconstruct an image.
+
+        Args:
+            images (Tensor): Input images.
+        Returns:
+            images_recon (Tensor): Reconstructed images.
+        """
+        content = self.content_encoder(images)
+        images_recon = self.decoder(content)
+        return images_recon
+
+
+class ContentEncoder(nn.Module):
+    r"""Improved UNIT encoder. The network consists of:
+
+    - input layers
+    - $(num_downsamples) convolutional blocks
+    - $(num_res_blocks) residual blocks.
+    - output layer.
+
+    Args:
+        num_downsamples (int): Number of times we reduce
+            resolution by 2x2.
+        num_res_blocks (int): Number of residual blocks at the end of the
+            content encoder.
+        num_image_channels (int): Number of input image channels.
+        num_filters (int): Base filter numbers.
+        max_num_filters (int): Maximum number of filters in the encoder.
+        padding_mode (string): Type of padding.
+        activation_norm_type (str): Type of activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+        nonlinearity (str): Type of nonlinear activation function.
+        pre_act (bool): If ``True``, uses pre-activation residual blocks.
+    """
+
+    def __init__(self,
+                 num_downsamples,
+                 num_res_blocks,
+                 num_image_channels,
+                 num_filters,
+                 max_num_filters,
+                 padding_mode,
+                 activation_norm_type,
+                 weight_norm_type,
+                 nonlinearity,
+                 pre_act=False):
+        super().__init__()
+        conv_params = dict(padding_mode=padding_mode,
+                           activation_norm_type=activation_norm_type,
+                           weight_norm_type=weight_norm_type,
+                           nonlinearity=nonlinearity)
+        # Whether or not it is safe to use inplace nonlinear activation.
+        if not pre_act or (activation_norm_type != '' and
+                           activation_norm_type != 'none'):
+            conv_params['inplace_nonlinearity'] = True
+
+        # The order of operations in residual blocks.
+        order = 'pre_act' if pre_act else 'CNACNA'
+
+        model = []
+        model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
+                              **conv_params)]
+
+        # Downsampling blocks.
+        for i in range(num_downsamples):
+            num_filters_prev = num_filters
+            num_filters = min(num_filters * 2, max_num_filters)
+            model += [Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1,
+                                  **conv_params)]
+
+        # Residual blocks.
+        for _ in range(num_res_blocks):
+            model += [Res2dBlock(num_filters, num_filters,
+                                 **conv_params,
+                                 order=order)]
+        self.model = nn.Sequential(*model)
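+        # Expose the final channel count so the decoder can be sized to match.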
+        self.output_dim = num_filters
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input image.
+        """
+        return self.model(x)
+
+
+class Decoder(nn.Module):
+    r"""Improved UNIT decoder. The network consists of:
+
+    - $(num_res_blocks) residual blocks.
+    - $(num_upsamples) residual blocks or convolutional blocks
+    - output layer.
+
+    Args:
+        num_upsamples (int): Number of times we increase resolution by 2x2.
+        num_res_blocks (int): Number of residual blocks.
+        num_filters (int): Base filter numbers.
+        num_image_channels (int): Number of input image channels.
+        padding_mode (string): Type of padding.
+        activation_norm_type (str): Type of activation normalization.
+        weight_norm_type (str): Type of weight normalization.
+        nonlinearity (str): Type of nonlinear activation function.
+        output_nonlinearity (str): Type of nonlinearity before final output,
+            ``'tanh'`` or ``'none'``.
+        pre_act (bool): If ``True``, uses pre-activation residual blocks.
+        apply_noise (bool): If ``True``, injects Gaussian noise.
+    """
+
+    def __init__(self,
+                 num_upsamples,
+                 num_res_blocks,
+                 num_filters,
+                 num_image_channels,
+                 padding_mode,
+                 activation_norm_type,
+                 weight_norm_type,
+                 nonlinearity,
+                 output_nonlinearity,
+                 pre_act=False,
+                 apply_noise=False):
+        super().__init__()
+
+        conv_params = dict(padding_mode=padding_mode,
+                           nonlinearity=nonlinearity,
+                           inplace_nonlinearity=True,
+                           apply_noise=apply_noise,
+                           weight_norm_type=weight_norm_type,
+                           activation_norm_type=activation_norm_type)
+
+        # The order of operations in residual blocks.
+        order = 'pre_act' if pre_act else 'CNACNA'
+
+        # Residual blocks.
+        self.decoder = nn.ModuleList()
+        for _ in range(num_res_blocks):
+            self.decoder += [Res2dBlock(num_filters, num_filters,
+                                        **conv_params,
+                                        order=order)]
+
+        # Convolutional blocks with upsampling.
+        for i in range(num_upsamples):
+            self.decoder += [NearestUpsample(scale_factor=2)]
+            self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
+                                         5, 1, 2, **conv_params)]
+            num_filters //= 2
+        self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
+                                     nonlinearity=output_nonlinearity,
+                                     padding_mode=padding_mode)]
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Content embedding of the content image.
+        """
+        for block in self.decoder:
+            x = block(x)
+        return x
diff --git a/imaginaire/generators/vid2vid.py b/imaginaire/generators/vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..78262debc9d2ea6106a0816b53c4c73c5a5a6053
--- /dev/null
+++ b/imaginaire/generators/vid2vid.py
@@ -0,0 +1,481 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.generators.fs_vid2vid import LabelEmbedder
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
+                                               resample)
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+                                   get_paired_input_label_channel_number)
+from imaginaire.utils.init_weight import weights_init
+
+
+class BaseNetwork(nn.Module):
+    r"""vid2vid generator."""
+
+    def __init__(self):
+        super(BaseNetwork, self).__init__()
+
+    def get_num_filters(self, num_downsamples):
+        r"""Get the number of filters at current layer.
+
+        Args:
+            num_downsamples (int) : How many downsamples at current layer.
+        Returns:
+            output (int) : Number of filters.
+        """
+        return min(self.max_num_filters,
+                   self.num_filters * (2 ** num_downsamples))
+
+
+class Generator(BaseNetwork):
+    r"""vid2vid generator constructor.
+
+    Args:
+        gen_cfg (obj): Generator definition part of the yaml config file.
+        data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        super().__init__()
+        self.gen_cfg = gen_cfg
+        self.data_cfg = data_cfg
+        self.num_frames_G = data_cfg.num_frames_G
+        # Number of residual blocks in generator.
+        self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
+        # Number of downsamplings for previous frame.
+        self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
+        # Number of filters in the first layer.
+        self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
+        self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
+        self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
+        padding = kernel_size // 2
+
+        # For pose dataset.
+        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
+        if self.is_pose_data:
+            pose_cfg = data_cfg.for_pose_dataset
+            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
+            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
+                                              False)
+
+        # Input data params.
+        num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        # Keep the label channel count; get_cond_dims() needs it when
+        # use_embed is False.
+        self.num_input_channels = num_input_channels
+        num_img_channels = get_paired_input_image_channel_number(data_cfg)
+        aug_cfg = data_cfg.val.augmentations
+        if hasattr(aug_cfg, 'center_crop_h_w'):
+            crop_h_w = aug_cfg.center_crop_h_w
+        elif hasattr(aug_cfg, 'resize_h_w'):
+            crop_h_w = aug_cfg.resize_h_w
+        else:
+            raise ValueError('Need to specify output size.')
+        crop_h, crop_w = crop_h_w.split(',')
+        crop_h, crop_w = int(crop_h), int(crop_w)
+        # Spatial size at the bottle neck of generator.
+        self.sh = crop_h // (2 ** num_layers)
+        self.sw = crop_w // (2 ** num_layers)
+
+        # Noise vector dimension.
+        self.z_dim = getattr(gen_cfg, 'style_dims', 256)
+        self.use_segmap_as_input = \
+            getattr(gen_cfg, 'use_segmap_as_input', False)
+
+        # Label / image embedding network.
+        self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
+        self.use_embed = getattr(emb_cfg, 'use_embed', True)
+        self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
+        if self.use_embed:
+            self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)
+
+        # Flow network.
+        self.flow_cfg = flow_cfg = gen_cfg.flow
+        # Use SPADE to combine warped and hallucinated frames instead of
+        # linear combination.
+        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
+        # Number of layers to perform multi-spade combine.
+        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
+                                              'num_layers', 3)
+        # At beginning of training, only train an image generator.
+        self.temporal_initialized = False
+        # Whether to output hallucinated frame (when training temporal network)
+        # for additional loss.
+        self.generate_raw_output = False
+
+        # Image generation network.
+        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
+        activation_norm_type = gen_cfg.activation_norm_type
+        activation_norm_params = gen_cfg.activation_norm_params
+        if self.use_embed and \
+                not hasattr(activation_norm_params, 'num_filters'):
+            activation_norm_params.num_filters = 0
+        nonlinearity = 'leakyrelu'
+
+        self.base_res_block = base_res_block = partial(
+            Res2dBlock, kernel_size=kernel_size, padding=padding,
+            weight_norm_type=weight_norm_type,
+            activation_norm_type=activation_norm_type,
+            activation_norm_params=activation_norm_params,
+            nonlinearity=nonlinearity, order='NACNAC')
+
+        # Upsampling residual blocks.
+        for i in range(num_layers, -1, -1):
+            activation_norm_params.cond_dims = self.get_cond_dims(i)
+            activation_norm_params.partial = \
+                self.get_partial(i) if hasattr(self, 'get_partial') else False
+            layer = base_res_block(self.get_num_filters(i + 1),
+                                   self.get_num_filters(i))
+            setattr(self, 'up_%d' % i, layer)
+
+        # Final conv layer.
+        self.conv_img = Conv2dBlock(num_filters, num_img_channels,
+                                    kernel_size, padding=padding,
+                                    nonlinearity=nonlinearity, order='AC')
+
+        num_filters = min(self.max_num_filters,
+                          num_filters * (2 ** (self.num_layers + 1)))
+        if self.use_segmap_as_input:
+            self.fc = Conv2dBlock(num_input_channels, num_filters,
+                                  kernel_size=3, padding=1)
+        else:
+            self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)
+
+        # Misc.
+        self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+        self.upsample = partial(F.interpolate, scale_factor=2)
+        self.init_temporal_network()
+
+    def forward(self, data):
+        r"""vid2vid generator forward.
+
+        Args:
+           data (dict) : Dictionary of input data.
+        Returns:
+           output (dict) : Dictionary of output data.
+        """
+        label = data['label']
+        label_prev, img_prev = data['prev_labels'], data['prev_images']
+        is_first_frame = img_prev is None
+        z = data.get('z', None)  # data is a dict; getattr would always miss 'z'.
+        bs, _, h, w = label.size()
+
+        if self.is_pose_data:
+            label, label_prev = extract_valid_pose_labels(
+                [label, label_prev], self.pose_type, self.remove_face_labels)
+
+        # Get SPADE conditional maps by embedding current label input.
+        cond_maps_now = self.get_cond_maps(label, self.label_embedding)
+
+        # Input to the generator will either be noise/segmentation map (for
+        # first frame) or encoded previous frame (for subsequent frames).
+        if is_first_frame:
+            # First frame in the sequence, start from scratch.
+            if self.use_segmap_as_input:
+                x_img = F.interpolate(label, size=(self.sh, self.sw))
+                x_img = self.fc(x_img)
+            else:
+                if z is None:
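+                    # randn(...).fill_(0) yields an all-zero z; randn here only
+                    # fixes the dtype and device of the tensor.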
+                    z = torch.randn(bs, self.z_dim, dtype=label.dtype,
+                                    device=label.get_device()).fill_(0)
+                x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
+
+            # Upsampling layers.
+            for i in range(self.num_layers, self.num_downsamples_img, -1):
+                j = min(self.num_downsamples_embed, i)
+                x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j])
+                x_img = self.upsample(x_img)
+        else:
+            # Not the first frame, will encode the previous frame and feed to
+            # the generator.
+            x_img = self.down_first(img_prev[:, -1])
+
+            # Get label embedding for the previous frame.
+            cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
+                                                self.label_embedding)
+
+            # Downsampling layers.
+            for i in range(self.num_downsamples_img + 1):
+                j = min(self.num_downsamples_embed, i)
+                x_img = getattr(self, 'down_' + str(i))(x_img,
+                                                        *cond_maps_prev[j])
+                if i != self.num_downsamples_img:
+                    x_img = self.downsample(x_img)
+
+            # Resnet blocks.
+            j = min(self.num_downsamples_embed, self.num_downsamples_img + 1)
+            for i in range(self.num_res_blocks):
+                cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \
+                    else cond_maps_now[j]
+                x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)
+
+        flow = mask = img_warp = None
+
+        num_frames_G = self.num_frames_G
+        # Whether to warp the previous frame or not.
+        warp_prev = self.temporal_initialized and not is_first_frame and \
+            label_prev.shape[1] == num_frames_G - 1
+        if warp_prev:
+            # Estimate flow & mask.
+            label_concat = torch.cat([label_prev.view(bs, -1, h, w),
+                                      label], dim=1)
+            img_prev_concat = img_prev.view(bs, -1, h, w)
+            flow, mask = self.flow_network_temp(label_concat, img_prev_concat)
+            img_warp = resample(img_prev[:, -1], flow)
+            if self.spade_combine:
+                # if using SPADE combine, integrate the warped image (and
+                # occlusion mask) into conditional inputs for SPADE.
+                img_embed = torch.cat([img_warp, mask], dim=1)
+                cond_maps_img = self.get_cond_maps(img_embed,
+                                                   self.img_prev_embedding)
+                x_raw_img = None
+
+        # Main image generation branch.
+        for i in range(self.num_downsamples_img, -1, -1):
+            # Get SPADE conditional inputs.
+            j = min(i, self.num_downsamples_embed)
+            cond_maps = cond_maps_now[j]
+
+            # For raw output generation.
+            if self.generate_raw_output:
+                if i >= self.num_multi_spade_layers - 1:
+                    x_raw_img = x_img
+                if i < self.num_multi_spade_layers:
+                    x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i)
+
+            # For final output.
+            if warp_prev and i < self.num_multi_spade_layers:
+                cond_maps += cond_maps_img[j]
+            x_img = self.one_up_conv_layer(x_img, cond_maps, i)
+
+        # Final conv layer.
+        img_final = torch.tanh(self.conv_img(x_img))
+
+        img_raw = None
+        if self.spade_combine and self.generate_raw_output:
+            img_raw = torch.tanh(self.conv_img(x_raw_img))
+        if warp_prev and not self.spade_combine:
+            img_raw = img_final
+            img_final = img_final * mask + img_warp * (1 - mask)
+
+        output = dict()
+        output['fake_images'] = img_final
+        output['fake_flow_maps'] = flow
+        output['fake_occlusion_masks'] = mask
+        output['fake_raw_images'] = img_raw
+        output['warped_images'] = img_warp
+        return output
+
+    def one_up_conv_layer(self, x, encoded_label, i):
+        r"""One residual block layer in the main branch.
+
+        Args:
+           x (4D tensor) : Current feature map.
+           encoded_label (list of tensors) : Encoded input label maps.
+           i (int) : Layer index.
+        Returns:
+           x (4D tensor) : Output feature map.
+        """
+        layer = getattr(self, 'up_' + str(i))
+        x = layer(x, *encoded_label)
+        if i != 0:
+            x = self.upsample(x)
+        return x
+
+    def init_temporal_network(self, cfg_init=None):
+        r"""When starting training multiple frames, initialize the
+        downsampling network and flow network.
+
+        Args:
+            cfg_init (dict) : Weight initialization config.
+        """
+        # Number of image downsamplings for the previous frame.
+        num_downsamples_img = self.num_downsamples_img
+        # Number of residual blocks for the previous frame.
+        self.num_res_blocks = int(
+            np.ceil((self.num_layers - num_downsamples_img) / 2.0) * 2)
+
+        # First conv layer.
+        num_img_channels = get_paired_input_image_channel_number(self.data_cfg)
+        self.down_first = \
+            Conv2dBlock(num_img_channels,
+                        self.num_filters, self.kernel_size,
+                        padding=self.kernel_size // 2)
+        if cfg_init is not None:
+            self.down_first.apply(weights_init(cfg_init.type, cfg_init.gain))
+
+        # Downsampling residual blocks.
+        activation_norm_params = self.gen_cfg.activation_norm_params
+        for i in range(num_downsamples_img + 1):
+            activation_norm_params.cond_dims = self.get_cond_dims(i)
+            layer = self.base_res_block(self.get_num_filters(i),
+                                        self.get_num_filters(i + 1))
+            if cfg_init is not None:
+                layer.apply(weights_init(cfg_init.type, cfg_init.gain))
+            setattr(self, 'down_%d' % i, layer)
+
+        # Additional residual blocks.
+        res_ch = self.get_num_filters(num_downsamples_img + 1)
+        activation_norm_params.cond_dims = \
+            self.get_cond_dims(num_downsamples_img + 1)
+        for i in range(self.num_res_blocks):
+            layer = self.base_res_block(res_ch, res_ch)
+            if cfg_init is not None:
+                layer.apply(weights_init(cfg_init.type, cfg_init.gain))
+            setattr(self, 'res_%d' % i, layer)
+
+        # Flow network.
+        flow_cfg = self.flow_cfg
+        self.temporal_initialized = True
+        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
+                                           False) and self.spade_combine
+        self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg)
+        if cfg_init is not None:
+            self.flow_network_temp.apply(weights_init(cfg_init.type,
+                                                      cfg_init.gain))
+
+        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
+        if self.spade_combine:
+            emb_cfg = flow_cfg.multi_spade_combine.embed
+            num_img_channels = get_paired_input_image_channel_number(
+                self.data_cfg)
+            self.img_prev_embedding = LabelEmbedder(emb_cfg,
+                                                    num_img_channels + 1)
+            if cfg_init is not None:
+                self.img_prev_embedding.apply(weights_init(cfg_init.type,
+                                                           cfg_init.gain))
+
+    def get_cond_dims(self, num_downs=0):
+        r"""Get the dimensions of conditional inputs.
+
+        Args:
+           num_downs (int) : How many downsamples at current layer.
+        Returns:
+           ch (list) : List of dimensions.
+        """
+        if not self.use_embed:
+            ch = [self.num_input_channels]
+        else:
+            num_filters = getattr(self.emb_cfg, 'num_filters', 32)
+            num_downs = min(num_downs, self.num_downsamples_embed)
+            ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
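+            # Layers that also receive the warped-image embedding (multi-SPADE
+            # combine) take twice the conditional inputs.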
+            if (num_downs < self.num_multi_spade_layers):
+                ch = ch * 2
+        return ch
+
+    def get_cond_maps(self, label, embedder):
+        r"""Get the conditional inputs.
+
+        Args:
+           label (4D tensor) : Input label tensor.
+           embedder (obj) : Embedding network.
+        Returns:
+           cond_maps (list) : List of conditional inputs.
+        """
+        if not self.use_embed:
+            return [label] * (self.num_layers + 1)
+        embedded_label = embedder(label)
+        cond_maps = [embedded_label]
+        cond_maps = [[m[i] for m in cond_maps] for i in
+                     range(len(cond_maps[0]))]
+        return cond_maps
+
+
+class FlowGenerator(BaseNetwork):
+    r"""Flow generator constructor.
+
+    Args:
+       flow_cfg (obj): Flow definition part of the yaml config file.
+       data_cfg (obj): Data definition part of the yaml config file.
+    """
+
+    def __init__(self, flow_cfg, data_cfg):
+        super().__init__()
+        num_input_channels = get_paired_input_label_channel_number(data_cfg)
+        num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
+        num_frames = data_cfg.num_frames_G  # Num. of input frames.
+
+        self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
+        self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
+        num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
+        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
+        padding = kernel_size // 2
+        self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
+        # Multiplier on the flow output.
+        self.flow_output_multiplier = getattr(flow_cfg,
+                                              'flow_output_multiplier', 20)
+
+        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
+                                       'sync_batch')
+        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
+
+        base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
+                                  padding=padding,
+                                  weight_norm_type=weight_norm_type,
+                                  activation_norm_type=activation_norm_type,
+                                  nonlinearity='leakyrelu')
+
+        # Will downsample the labels and prev frames separately, then combine.
+        down_lbl = [base_conv_block(num_input_channels * num_frames,
+                                    num_filters)]
+        down_img = [base_conv_block(num_prev_img_channels * (num_frames - 1),
+                                    num_filters)]
+        for i in range(num_downsamples):
+            down_lbl += [base_conv_block(self.get_num_filters(i),
+                                         self.get_num_filters(i + 1),
+                                         stride=2)]
+            down_img += [base_conv_block(self.get_num_filters(i),
+                                         self.get_num_filters(i + 1),
+                                         stride=2)]
+
+        # Resnet blocks.
+        res_flow = []
+        ch = self.get_num_filters(num_downsamples)
+        for i in range(self.num_res_blocks):
+            res_flow += [
+                Res2dBlock(ch, ch, kernel_size, padding=padding,
+                           weight_norm_type=weight_norm_type,
+                           activation_norm_type=activation_norm_type,
+                           order='CNACN')]
+
+        # Upsample.
+        up_flow = []
+        for i in reversed(range(num_downsamples)):
+            up_flow += [nn.Upsample(scale_factor=2),
+                        base_conv_block(self.get_num_filters(i + 1),
+                                        self.get_num_filters(i))]
+
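+        # The flow head outputs 2 channels (dx, dy); the mask head outputs a
+        # soft occlusion mask in [0, 1] via a sigmoid.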
+        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
+        conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
+                                 nonlinearity='sigmoid')]
+
+        self.down_lbl = nn.Sequential(*down_lbl)
+        self.down_img = nn.Sequential(*down_img)
+        self.res_flow = nn.Sequential(*res_flow)
+        self.up_flow = nn.Sequential(*up_flow)
+        self.conv_flow = nn.Sequential(*conv_flow)
+        self.conv_mask = nn.Sequential(*conv_mask)
+
+    def forward(self, label, img_prev):
+        r"""Flow generator forward.
+
+        Args:
+           label (4D tensor) : Input label tensor.
+           img_prev (4D tensor) : Previously generated image tensors.
+        Returns:
+            (tuple):
+              - flow (4D tensor) : Generated flow map.
+              - mask (4D tensor) : Generated occlusion mask.
+        """
+        downsample = self.down_lbl(label) + self.down_img(img_prev)
+        res = self.res_flow(downsample)
+        flow_feat = self.up_flow(res)
+        flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
+        mask = self.conv_mask(flow_feat)
+        return flow, mask
diff --git a/imaginaire/generators/wc_vid2vid.py b/imaginaire/generators/wc_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8139a3c05140a567ef202eccd1d9209ba681b4
--- /dev/null
+++ b/imaginaire/generators/wc_vid2vid.py
@@ -0,0 +1,354 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+from imaginaire.config import Config
+from imaginaire.generators.vid2vid import Generator as Vid2VidGenerator
+from imaginaire.model_utils.fs_vid2vid import resample
+from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer
+from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
+                                      get_trainer)
+from imaginaire.utils.visualization import tensor2im
+
+
+class Generator(Vid2VidGenerator):
+    r"""world consistent vid2vid generator constructor.
+
+    Args:
+       gen_cfg (obj): Generator definition part of the yaml config file.
+       data_cfg (obj): Data definition part of the yaml config file
+    """
+
+    def __init__(self, gen_cfg, data_cfg):
+        # Guidance options.
+        self.guidance_cfg = gen_cfg.guidance
+        self.guidance_only_with_flow = getattr(
+            self.guidance_cfg, 'only_with_flow', False)
+        self.guidance_partial_conv = getattr(
+            self.guidance_cfg, 'partial_conv', False)
+
+        # Splatter for guidance.
+        self.renderer = SplatRenderer()
+        self.reset_renderer()
+
+        # Single image model.
+        self.single_image_model = None
+
+        # Initialize the rest same as vid2vid.
+        super().__init__(gen_cfg, data_cfg)
+
+    def _init_single_image_model(self, load_weights=True):
+        r"""Load single image model, if any."""
+        if self.single_image_model is None and \
+                hasattr(self.gen_cfg, 'single_image_model'):
+            print('Using single image model...')
+            single_image_cfg = Config(self.gen_cfg.single_image_model.config)
+
+            # Init model.
+            net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
+                get_model_optimizer_and_scheduler(single_image_cfg)
+
+            # Init trainer and load checkpoint.
+            trainer = get_trainer(single_image_cfg, net_G, net_D,
+                                  opt_G, opt_D,
+                                  sch_G, sch_D,
+                                  None, None)
+            if load_weights:
+                print('Loading single image model checkpoint')
+                single_image_ckpt = self.gen_cfg.single_image_model.checkpoint
+                trainer.load_checkpoint(single_image_cfg, single_image_ckpt)
+                print('Loaded single image model checkpoint')
+
+            self.single_image_model = net_G.module
+            self.single_image_model_z = None
+
+    def reset_renderer(self, is_flipped_input=False):
+        r"""Reset the renderer.
+        Args:
+            is_flipped_input (bool): Is the input sequence left-right flipped?
+        """
+        self.renderer.reset()
+        self.is_flipped_input = is_flipped_input
+        self.renderer_num_forwards = 0
+        self.single_image_model_z = None
+
+    def renderer_update_point_cloud(self, image, point_info):
+        r"""Update the renderer's color dictionary."""
+        if point_info is None or len(point_info) == 0:
+            return
+        # print('Updating the renderer.')
+        _, _, h, w = image.size()
+
+        # Renderer expects (h, w, c) [0-255] RGB image.
+        if isinstance(image, torch.Tensor):
+            image = tensor2im(image.detach())[0]
+
+        # Flip this image to correspond to SfM camera pose.
+        if self.is_flipped_input:
+            image = np.fliplr(image).copy()
+
+        self.renderer.update_point_cloud(image, point_info)
+        self.renderer_num_forwards += 1
+
+    def get_guidance_images_and_masks(self, unprojection):
+        r"""Do stuff."""
+
+        resolution = 'w1024xh512'
+        point_info = unprojection[resolution]
+
+        w, h = resolution.split('x')
+        w, h = int(w[1:]), int(h[1:])
+
+        # This returns guidance image in [0-255] RGB.
+        # We will convert it into Tensor repr. below.
+        guidance_image, guidance_mask = self.renderer.render_image(
+            point_info, w, h, return_mask=True)
+
+        # If mask is None, there is no guidance.
+        # print(np.sum(guidance_mask), guidance_mask.size)
+        # if np.sum(guidance_mask) == 0:
+        #     return None, point_info
+
+        # Flip guidance image and guidance mask if needed.
+        if self.is_flipped_input:
+            guidance_image = np.fliplr(guidance_image).copy()
+            guidance_mask = np.fliplr(guidance_mask).copy()
+
+        # Go from (h, w, c) to (1, c, h, w).
+        # Convert guidance image to Tensor.
+        guidance_image = (transforms.ToTensor()(guidance_image) - 0.5) * 2
+        guidance_mask = transforms.ToTensor()(guidance_mask)
+        guidance = torch.cat((guidance_image, guidance_mask), dim=0)
+        guidance = guidance.unsqueeze(0).cuda()
+
+        # Save guidance at all resolutions.
+        guidance_images_and_masks = guidance
+
+        return guidance_images_and_masks, point_info
+
+    def forward(self, data):
+        r"""vid2vid generator forward.
+        Args:
+           data (dict) : Dictionary of input data.
+        Returns:
+           output (dict) : Dictionary of output data.
+        """
+        self._init_single_image_model()
+
+        label = data['label']
+        unprojection = data['unprojection']
+        label_prev, img_prev = data['prev_labels'], data['prev_images']
+        is_first_frame = img_prev is None
+        z = data.get('z', None)  # data is a dict; getattr would always miss 'z'.
+        bs, _, h, w = label.size()
+
+        # Whether to warp the previous frame or not.
+        flow = mask = img_warp = None
+        warp_prev = self.temporal_initialized and not is_first_frame and \
+            label_prev.shape[1] == self.num_frames_G - 1
+
+        # Get guidance images and masks.
+        guidance_images_and_masks, point_info = None, None
+        if unprojection is not None:
+            guidance_images_and_masks, point_info = \
+                self.get_guidance_images_and_masks(unprojection)
+
+        # Get SPADE conditional maps by embedding current label input.
+        cond_maps_now = self.get_cond_maps(label, self.label_embedding)
+
+        # Use single image model, if flow features are not available.
+        # Guidance features are used whenever flow features are available.
+        if self.single_image_model is not None and not warp_prev:
+            # Get z vector for single image model.
+            if self.single_image_model_z is None:
+                bs = data['label'].size(0)
+                z = torch.randn(bs, self.single_image_model.style_dims,
+                                dtype=torch.float32).cuda()
+                if data['label'].dtype == torch.float16:
+                    z = z.half()
+                self.single_image_model_z = z
+
+            # Get output image.
+            data['z'] = self.single_image_model_z
+            self.single_image_model.eval()
+            with torch.no_grad():
+                output = self.single_image_model.spade_generator(data)
+            img_final = output['fake_images'].detach()
+            fake_images_source = 'pretrained'
+        else:
+            # Input to the generator will either be noise/segmentation map (for
+            # first frame) or encoded previous frame (for subsequent frames).
+            if is_first_frame:
+                # First frame in the sequence, start from scratch.
+                if self.use_segmap_as_input:
+                    x_img = F.interpolate(label, size=(self.sh, self.sw))
+                    x_img = self.fc(x_img)
+                else:
+                    if z is None:
+                        z = torch.randn(bs, self.z_dim, dtype=label.dtype,
+                                        device=label.get_device()).fill_(0)
+                    x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
+
+                # Upsampling layers.
+                for i in range(self.num_layers, self.num_downsamples_img, -1):
+                    j = min(self.num_downsamples_embed, i)
+                    x_img = getattr(self, 'up_' + str(i)
+                                    )(x_img, *cond_maps_now[j])
+                    x_img = self.upsample(x_img)
+            else:
+                # Not the first frame, will encode the previous frame and feed
+                # to the generator.
+                x_img = self.down_first(img_prev[:, -1])
+
+                # Get label embedding for the previous frame.
+                cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
+                                                    self.label_embedding)
+
+                # Downsampling layers.
+                for i in range(self.num_downsamples_img + 1):
+                    j = min(self.num_downsamples_embed, i)
+                    x_img = getattr(self, 'down_' + str(i))(x_img,
+                                                            *cond_maps_prev[j])
+                    if i != self.num_downsamples_img:
+                        x_img = self.downsample(x_img)
+
+                # Resnet blocks.
+                j = min(self.num_downsamples_embed,
+                        self.num_downsamples_img + 1)
+                for i in range(self.num_res_blocks):
+                    cond_maps = cond_maps_prev[j] if \
+                        i < self.num_res_blocks // 2 else cond_maps_now[j]
+                    x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)
+
+            # Optical flow warped image features.
+            if warp_prev:
+                # Estimate flow & mask.
+                label_concat = torch.cat([label_prev.view(bs, -1, h, w),
+                                          label], dim=1)
+                img_prev_concat = img_prev.view(bs, -1, h, w)
+                flow, mask = self.flow_network_temp(
+                    label_concat, img_prev_concat)
+                img_warp = resample(img_prev[:, -1], flow)
+                if self.spade_combine:
+                    # if using SPADE combine, integrate the warped image (and
+                    # occlusion mask) into conditional inputs for SPADE.
+                    img_embed = torch.cat([img_warp, mask], dim=1)
+                    cond_maps_img = self.get_cond_maps(img_embed,
+                                                       self.img_prev_embedding)
+                    x_raw_img = None
+
+            # Main image generation branch.
+            for i in range(self.num_downsamples_img, -1, -1):
+                # Get SPADE conditional inputs.
+                j = min(i, self.num_downsamples_embed)
+                cond_maps = cond_maps_now[j]
+
+                # For raw output generation.
+                if self.generate_raw_output:
+                    if i >= self.num_multi_spade_layers - 1:
+                        x_raw_img = x_img
+                    if i < self.num_multi_spade_layers:
+                        x_raw_img = self.one_up_conv_layer(
+                            x_raw_img, cond_maps, i)
+
+                # Add flow and guidance features.
+                if warp_prev:
+                    if i < self.num_multi_spade_layers:
+                        # Add flow.
+                        cond_maps += cond_maps_img[j]
+                        # Add guidance.
+                        if guidance_images_and_masks is not None:
+                            cond_maps += [guidance_images_and_masks]
+                    elif not self.guidance_only_with_flow:
+                        # Add guidance if it is to be applied to every layer.
+                        if guidance_images_and_masks is not None:
+                            cond_maps += [guidance_images_and_masks]
+
+                x_img = self.one_up_conv_layer(x_img, cond_maps, i)
+
+            # Final conv layer.
+            img_final = torch.tanh(self.conv_img(x_img))
+            fake_images_source = 'in_training'
+
+        # Update the point cloud color dict of renderer.
+        self.renderer_update_point_cloud(img_final, point_info)
+
+        output = dict()
+        output['fake_images'] = img_final
+        output['fake_flow_maps'] = flow
+        output['fake_occlusion_masks'] = mask
+        output['fake_raw_images'] = None
+        output['warped_images'] = img_warp
+        output['guidance_images_and_masks'] = guidance_images_and_masks
+        output['fake_images_source'] = fake_images_source
+        return output
+
+    def get_cond_dims(self, num_downs=0):
+        r"""Get the dimensions of conditional inputs.
+        Args:
+           num_downs (int) : How many downsamples at current layer.
+        Returns:
+           ch (list) : List of dimensions.
+        """
+        if not self.use_embed:
+            ch = [self.num_input_channels]
+        else:
+            num_filters = getattr(self.emb_cfg, 'num_filters', 32)
+            num_downs = min(num_downs, self.num_downsamples_embed)
+            ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
+            if (num_downs < self.num_multi_spade_layers):
+                ch = ch * 2
+                # Also add guidance (RGB + mask = 4 channels, or 3 if partial).
+                if self.guidance_partial_conv:
+                    ch.append(3)
+                else:
+                    ch.append(4)
+            elif not self.guidance_only_with_flow:
+                if self.guidance_partial_conv:
+                    ch.append(3)
+                else:
+                    ch.append(4)
+        return ch
+
+    def get_partial(self, num_downs=0):
+        r"""Get if convs should be partial or not.
+        Args:
+           num_downs (int) : How many downsamples at current layer.
+        Returns:
+           partial (list) : List of boolean partial or not.
+        """
+        partial = [False]
+        if (num_downs < self.num_multi_spade_layers):
+            partial = partial * 2
+            # Also flag the guidance input (partial conv only if enabled).
+            if self.guidance_partial_conv:
+                partial.append(True)
+            else:
+                partial.append(False)
+        elif not self.guidance_only_with_flow:
+            if self.guidance_partial_conv:
+                partial.append(True)
+            else:
+                partial.append(False)
+        return partial
+
+    def get_cond_maps(self, label, embedder):
+        r"""Get the conditional inputs.
+        Args:
+           label (4D tensor) : Input label tensor.
+           embedder (obj) : Embedding network.
+        Returns:
+           cond_maps (list) : List of conditional inputs.
+        """
+        if not self.use_embed:
+            return [label] * (self.num_layers + 1)
+        embedded_label = embedder(label)
+        cond_maps = [embedded_label]
+        cond_maps = [[m[i] for m in cond_maps] for i in
+                     range(len(cond_maps[0]))]
+        return cond_maps
diff --git a/imaginaire/layers/__init__.py b/imaginaire/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e3f93c154678b630de93ab3d6b199204c4fd8fb
--- /dev/null
+++ b/imaginaire/layers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \
+    HyperConv2dBlock, MultiOutConv2dBlock, \
+    PartialConv2dBlock, PartialConv3dBlock
+from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \
+    HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \
+    PartialRes2dBlock, PartialRes3dBlock
+from .non_local import NonLocal2dBlock
+
+__all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock',
+           'HyperConv2dBlock', 'MultiOutConv2dBlock',
+           'PartialConv2dBlock', 'PartialConv3dBlock',
+           'Res1dBlock', 'Res2dBlock', 'Res3dBlock',
+           'UpRes2dBlock', 'DownRes2dBlock',
+           'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock',
+           'PartialRes2dBlock', 'PartialRes3dBlock',
+           'NonLocal2dBlock']
+
+try:
+    from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock
+    from .attn import MultiheadAttention
+    __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock'])
+except Exception:  # noqa
+    pass
diff --git a/imaginaire/layers/__pycache__/__init__.cpython-38.pyc b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..528061b6276196a6fdf6c254450dfc5f9f9f6a89
Binary files /dev/null and b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..754bbd93d5840bc721eaaede366f15c5903d286b
Binary files /dev/null and b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/conv.cpython-38.pyc b/imaginaire/layers/__pycache__/conv.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da61b20f0c5a954d150297691572cfcae2b75bee
Binary files /dev/null and b/imaginaire/layers/__pycache__/conv.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/misc.cpython-38.pyc b/imaginaire/layers/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0cf730aee80241b8f869fd876e4e43102a25790
Binary files /dev/null and b/imaginaire/layers/__pycache__/misc.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/non_local.cpython-38.pyc b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9333890d786503983cdaf4be1fb7458fbc215ea0
Binary files /dev/null and b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a3a57d74453c76df61eaed8c0a4294479611dc2
Binary files /dev/null and b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/residual.cpython-38.pyc b/imaginaire/layers/__pycache__/residual.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a64ea7db7e342783ae70a6348e6278f3fc030fc5
Binary files /dev/null and b/imaginaire/layers/__pycache__/residual.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4fffd5e3452b5dbedacfeada43c7c969a7026b7
Binary files /dev/null and b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc differ
diff --git a/imaginaire/layers/activation_norm.py b/imaginaire/layers/activation_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..709a928eed76f456c6284f575ba47f2dff581b39
--- /dev/null
+++ b/imaginaire/layers/activation_norm.py
@@ -0,0 +1,629 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa E722
+from types import SimpleNamespace
+
+import torch
+
+try:
+    from torch.nn import SyncBatchNorm
+except ImportError:
+    from torch.nn import BatchNorm2d as SyncBatchNorm
+from torch import nn
+from torch.nn import functional as F
+from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
+from .misc import PartialSequential, ApplyNoise
+
+
+class AdaptiveNorm(nn.Module):
+    r"""Adaptive normalization layer. The layer first normalizes the input, then
+    performs an affine transformation using parameters computed from the
+    conditional inputs.
+
+    Args:
+        num_features (int): Number of channels in the input tensor.
+        cond_dims (int): Number of channels in the conditional inputs.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
+        projection (bool): If ``True``, project the conditional input to gamma
+            and beta using a fully connected layer, otherwise directly use
+            the conditional input as gamma and beta.
+        projection_bias (bool): If ``True``, use bias in the fully connected
+            projection layer.
+        separate_projection (bool): If ``True``, we will use two different
+            layers for gamma and beta. Otherwise, we will use one layer. It
+            matters only if you apply any weight norms to this layer.
+        input_dim (int): Number of dimensions of the input tensor.
+        activation_norm_type (str):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+    """
+
+    def __init__(self, num_features, cond_dims, weight_norm_type='',
+                 projection=True,
+                 projection_bias=True,
+                 separate_projection=False,
+                 input_dim=2,
+                 activation_norm_type='instance',
+                 activation_norm_params=None,
+                 apply_noise=False,
+                 add_bias=True,
+                 input_scale=1.0,
+                 init_gain=1.0):
+        super().__init__()
+        if activation_norm_params is None:
+            activation_norm_params = SimpleNamespace(affine=False)
+        self.norm = get_activation_norm_layer(num_features,
+                                              activation_norm_type,
+                                              input_dim,
+                                              **vars(activation_norm_params))
+        if apply_noise:
+            self.noise_layer = ApplyNoise()
+        else:
+            self.noise_layer = None
+
+        if projection:
+            if separate_projection:
+                self.fc_gamma = \
+                    LinearBlock(cond_dims, num_features,
+                                weight_norm_type=weight_norm_type,
+                                bias=projection_bias)
+                self.fc_beta = \
+                    LinearBlock(cond_dims, num_features,
+                                weight_norm_type=weight_norm_type,
+                                bias=projection_bias)
+            else:
+                self.fc = LinearBlock(cond_dims, num_features * 2,
+                                      weight_norm_type=weight_norm_type,
+                                      bias=projection_bias)
+
+        self.projection = projection
+        self.separate_projection = separate_projection
+        self.input_scale = input_scale
+        self.add_bias = add_bias
+        self.conditional = True
+        self.init_gain = init_gain
+
+    def forward(self, x, y, noise=None, **_kwargs):
+        r"""Adaptive Normalization forward.
+
+        Args:
+            x (N x C1 x * tensor): Input tensor.
+            y (N x C2 tensor): Conditional information.
+        Returns:
+            out (N x C1 x * tensor): Output tensor.
+        """
+        y = y * self.input_scale
+        if self.projection:
+            if self.separate_projection:
+                gamma = self.fc_gamma(y)
+                beta = self.fc_beta(y)
+                for _ in range(x.dim() - gamma.dim()):
+                    gamma = gamma.unsqueeze(-1)
+                    beta = beta.unsqueeze(-1)
+            else:
+                y = self.fc(y)
+                for _ in range(x.dim() - y.dim()):
+                    y = y.unsqueeze(-1)
+                gamma, beta = y.chunk(2, 1)
+        else:
+            for _ in range(x.dim() - y.dim()):
+                y = y.unsqueeze(-1)
+            gamma, beta = y.chunk(2, 1)
+        if self.norm is not None:
+            x = self.norm(x)
+        if self.noise_layer is not None:
+            x = self.noise_layer(x, noise=noise)
+        if self.add_bias:
+            x = torch.addcmul(beta, x, 1 + gamma)
+            return x
+        else:
+            return x * (1 + gamma), beta.squeeze(3).squeeze(2)
+
+
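+# Editorial usage sketch, not part of the original imaginaire code: it shows how
+# AdaptiveNorm normalizes a feature map and then modulates it with gamma/beta
+# projected from a per-sample conditioning vector. Shapes and argument values
+# below are illustrative assumptions.
+def _example_adaptive_norm():
+    norm = AdaptiveNorm(num_features=64, cond_dims=16)
+    x = torch.randn(2, 64, 8, 8)   # input feature map (N, C, H, W)
+    y = torch.randn(2, 16)         # conditioning vector (N, cond_dims)
+    out = norm(x, y)               # instance norm, then beta + x * (1 + gamma)
+    return out.shape               # torch.Size([2, 64, 8, 8])
+
+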
+class SpatiallyAdaptiveNorm(nn.Module):
+    r"""Spatially Adaptive Normalization (SPADE) initialization.
+
+    Args:
+        num_features (int) : Number of channels in the input tensor.
+        cond_dims (int or list of int) : List of numbers of channels
+            in the input.
+        num_filters (int): Number of filters in SPADE.
+        kernel_size (int): Kernel size of the convolutional filters in
+            the SPADE layer.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        separate_projection (bool): If ``True``, we will use two different
+            layers for gamma and beta. Otherwise, we will use one layer. It
+            matters only if you apply any weight norms to this layer.
+        activation_norm_type (str):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+    """
+
+    def __init__(self,
+                 num_features,
+                 cond_dims,
+                 num_filters=128,
+                 kernel_size=3,
+                 weight_norm_type='',
+                 separate_projection=False,
+                 activation_norm_type='sync_batch',
+                 activation_norm_params=None,
+                 bias_only=False,
+                 partial=False,
+                 interpolation='nearest'):
+        super().__init__()
+        if activation_norm_params is None:
+            activation_norm_params = SimpleNamespace(affine=False)
+        padding = kernel_size // 2
+        self.separate_projection = separate_projection
+        self.mlps = nn.ModuleList()
+        self.gammas = nn.ModuleList()
+        self.betas = nn.ModuleList()
+        self.bias_only = bias_only
+        self.interpolation = interpolation
+
+        # Make cond_dims a list.
+        if type(cond_dims) != list:
+            cond_dims = [cond_dims]
+
+        # Make num_filters a list.
+        if not isinstance(num_filters, list):
+            num_filters = [num_filters] * len(cond_dims)
+        else:
+            assert len(num_filters) >= len(cond_dims)
+
+        # Make partial a list.
+        if not isinstance(partial, list):
+            partial = [partial] * len(cond_dims)
+        else:
+            assert len(partial) >= len(cond_dims)
+
+        for i, cond_dim in enumerate(cond_dims):
+            mlp = []
+            conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
+            sequential = PartialSequential if partial[i] else nn.Sequential
+
+            if num_filters[i] > 0:
+                mlp += [conv_block(cond_dim,
+                                   num_filters[i],
+                                   kernel_size,
+                                   padding=padding,
+                                   weight_norm_type=weight_norm_type,
+                                   nonlinearity='relu')]
+            mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
+
+            if self.separate_projection:
+                if partial[i]:
+                    raise NotImplementedError(
+                        'Separate projection not yet implemented for ' +
+                        'partial conv')
+                self.mlps.append(nn.Sequential(*mlp))
+                self.gammas.append(
+                    conv_block(mlp_ch, num_features,
+                               kernel_size,
+                               padding=padding,
+                               weight_norm_type=weight_norm_type))
+                self.betas.append(
+                    conv_block(mlp_ch, num_features,
+                               kernel_size,
+                               padding=padding,
+                               weight_norm_type=weight_norm_type))
+            else:
+                mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
+                                   padding=padding,
+                                   weight_norm_type=weight_norm_type)]
+                self.mlps.append(sequential(*mlp))
+
+        self.norm = get_activation_norm_layer(num_features,
+                                              activation_norm_type,
+                                              2,
+                                              **vars(activation_norm_params))
+        self.conditional = True
+
+    def forward(self, x, *cond_inputs, **_kwargs):
+        r"""Spatially Adaptive Normalization (SPADE) forward.
+
+        Args:
+            x (N x C1 x H x W tensor) : Input tensor.
+            cond_inputs (list of tensors) : Conditional maps for SPADE.
+        Returns:
+            output (4D tensor) : Output tensor.
+        """
+        output = self.norm(x) if self.norm is not None else x
+        for i in range(len(cond_inputs)):
+            if cond_inputs[i] is None:
+                continue
+            label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation)
+            if self.separate_projection:
+                hidden = self.mlps[i](label_map)
+                gamma = self.gammas[i](hidden)
+                beta = self.betas[i](hidden)
+            else:
+                affine_params = self.mlps[i](label_map)
+                gamma, beta = affine_params.chunk(2, dim=1)
+            if self.bias_only:
+                output = output + beta
+            else:
+                output = output * (1 + gamma) + beta
+        return output
+
+
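+# Editorial usage sketch, not part of the original imaginaire code: SPADE
+# predicts per-pixel gamma/beta maps from a conditioning map (e.g. a
+# segmentation or guidance image) and applies them to the normalized input.
+# 'instance' norm is chosen so the sketch runs on CPU; values are illustrative.
+def _example_spatially_adaptive_norm():
+    spade = SpatiallyAdaptiveNorm(num_features=64, cond_dims=3,
+                                  activation_norm_type='instance')
+    x = torch.randn(2, 64, 16, 16)     # features to normalize
+    seg = torch.randn(2, 3, 64, 64)    # conditioning map, resized internally
+    out = spade(x, seg)                # out = norm(x) * (1 + gamma) + beta
+    return out.shape                   # torch.Size([2, 64, 16, 16])
+
+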
+class DualAdaptiveNorm(nn.Module):
+    def __init__(self,
+                 num_features,
+                 cond_dims,
+                 projection_bias=True,
+                 weight_norm_type='',
+                 activation_norm_type='instance',
+                 activation_norm_params=None,
+                 apply_noise=False,
+                 bias_only=False,
+                 init_gain=1.0,
+                 fc_scale=None,
+                 is_spatial=None):
+        super().__init__()
+        if activation_norm_params is None:
+            activation_norm_params = SimpleNamespace(affine=False)
+        self.mlps = nn.ModuleList()
+        self.gammas = nn.ModuleList()
+        self.betas = nn.ModuleList()
+        self.bias_only = bias_only
+
+        # Make cond_dims a list.
+        if type(cond_dims) != list:
+            cond_dims = [cond_dims]
+
+        if is_spatial is None:
+            is_spatial = [False for _ in range(len(cond_dims))]
+        self.is_spatial = is_spatial
+
+        for cond_dim, this_is_spatial in zip(cond_dims, is_spatial):
+            kwargs = dict(weight_norm_type=weight_norm_type,
+                          bias=projection_bias,
+                          init_gain=init_gain,
+                          output_scale=fc_scale)
+            if this_is_spatial:
+                self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
+                self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
+            else:
+                self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs))
+                self.betas.append(LinearBlock(cond_dim, num_features, **kwargs))
+
+        self.norm = get_activation_norm_layer(num_features,
+                                              activation_norm_type,
+                                              2,
+                                              **vars(activation_norm_params))
+        self.conditional = True
+
+    def forward(self, x, *cond_inputs, **_kwargs):
+        assert len(cond_inputs) == len(self.gammas)
+        output = self.norm(x) if self.norm is not None else x
+        for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas):
+            if cond is None:
+                continue
+            gamma = gamma_layer(cond)
+            beta = beta_layer(cond)
+            if cond.dim() == 4 and gamma.shape != x.shape:
+                gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear')
+                beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear')
+            elif cond.dim() == 2:
+                gamma = gamma[:, :, None, None]
+                beta = beta[:, :, None, None]
+            if self.bias_only:
+                output = output + beta
+            else:
+                output = output * (1 + gamma) + beta
+        return output
+
+
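+# Editorial usage sketch, not part of the original imaginaire code:
+# DualAdaptiveNorm predicts gamma and beta with separate projections per
+# conditional input; vector conditions use linear layers, spatial ones use
+# 1x1 convolutions. Values below are illustrative assumptions.
+def _example_dual_adaptive_norm():
+    norm = DualAdaptiveNorm(num_features=64, cond_dims=16)
+    x = torch.randn(2, 64, 8, 8)
+    cond = torch.randn(2, 16)          # one non-spatial conditioning vector
+    out = norm(x, cond)                # out = norm(x) * (1 + gamma) + beta
+    return out.shape                   # torch.Size([2, 64, 8, 8])
+
+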
+class HyperSpatiallyAdaptiveNorm(nn.Module):
+    r"""Spatially Adaptive Normalization (SPADE) initialization.
+
+    Args:
+        num_features (int) : Number of channels in the input tensor.
+        cond_dims (int or list of int) : List of numbers of channels
+            in the conditional input.
+        num_filters (int): Number of filters in SPADE.
+        kernel_size (int): Kernel size of the convolutional filters in
+            the SPADE layer.
+        weight_norm_type (str): Type of weight normalization.
+            ``'none'``, ``'spectral'``, or ``'weight'``.
+        activation_norm_type (str):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``.
+        is_hyper (bool): Whether to use hyper SPADE.
+    """
+
+    def __init__(self, num_features, cond_dims,
+                 num_filters=0, kernel_size=3,
+                 weight_norm_type='',
+                 activation_norm_type='sync_batch', is_hyper=True):
+        super().__init__()
+        padding = kernel_size // 2
+        self.mlps = nn.ModuleList()
+        if type(cond_dims) != list:
+            cond_dims = [cond_dims]
+
+        for i, cond_dim in enumerate(cond_dims):
+            mlp = []
+            if not is_hyper or (i != 0):
+                if num_filters > 0:
+                    mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
+                                        padding=padding,
+                                        weight_norm_type=weight_norm_type,
+                                        nonlinearity='relu')]
+                mlp_ch = cond_dim if num_filters == 0 else num_filters
+                mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
+                                    padding=padding,
+                                    weight_norm_type=weight_norm_type)]
+                mlp = nn.Sequential(*mlp)
+            else:
+                if num_filters > 0:
+                    raise ValueError('Multi hyper layer not supported yet.')
+                mlp = HyperConv2d(padding=padding)
+            self.mlps.append(mlp)
+
+        self.norm = get_activation_norm_layer(num_features,
+                                              activation_norm_type,
+                                              2,
+                                              affine=False)
+
+        self.conditional = True
+
+    def forward(self, x, *cond_inputs,
+                norm_weights=(None, None), **_kwargs):
+        r"""Spatially Adaptive Normalization (SPADE) forward.
+
+        Args:
+            x (4D tensor) : Input tensor.
+            cond_inputs (list of tensors) : Conditional maps for SPADE.
+            norm_weights (5D tensor or list of tensors): conv weights or
+            [weights, biases].
+        Returns:
+            output (4D tensor) : Output tensor.
+        """
+        output = self.norm(x)
+        for i in range(len(cond_inputs)):
+            if cond_inputs[i] is None:
+                continue
+            if type(cond_inputs[i]) == list:
+                cond_input, mask = cond_inputs[i]
+                mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False)
+            else:
+                cond_input = cond_inputs[i]
+                mask = None
+            label_map = F.interpolate(cond_input, size=x.size()[2:])
+            if norm_weights is None or norm_weights[0] is None or i != 0:
+                affine_params = self.mlps[i](label_map)
+            else:
+                affine_params = self.mlps[i](label_map,
+                                             conv_weights=norm_weights)
+            gamma, beta = affine_params.chunk(2, dim=1)
+            if mask is not None:
+                gamma = gamma * (1 - mask)
+                beta = beta * (1 - mask)
+            output = output * (1 + gamma) + beta
+        return output
+
+
+class LayerNorm2d(nn.Module):
+    r"""Layer Normalization as introduced in
+    https://arxiv.org/abs/1607.06450.
+    This is the usual way to apply layer normalization in CNNs.
+    Note that unlike the pytorch implementation which applies per-element
+    scale and bias, here it applies per-channel scale and bias, similar to
+    batch/instance normalization.
+
+    Args:
+        num_features (int): Number of channels in the input tensor.
+        eps (float, optional, default=1e-5): a value added to the
+            denominator for numerical stability.
+        affine (bool, optional, default=False): If ``True``, performs
+            affine transformation after normalization.
+    """
+
+    def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True):
+        super(LayerNorm2d, self).__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.eps = eps
+        self.channel_only = channel_only
+
+        if self.affine:
+            self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0))
+            self.beta = nn.Parameter(torch.zeros(num_features))
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+        """
+        shape = [-1] + [1] * (x.dim() - 1)
+        if self.channel_only:
+            mean = x.mean(1, keepdim=True)
+            std = x.std(1, keepdim=True)
+        else:
+            mean = x.view(x.size(0), -1).mean(1).view(*shape)
+            std = x.view(x.size(0), -1).std(1).view(*shape)
+
+        x = (x - mean) / (std + self.eps)
+
+        if self.affine:
+            shape = [1, -1] + [1] * (x.dim() - 2)
+            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
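+# Editorial usage sketch, not part of the original imaginaire code: LayerNorm2d
+# normalizes each sample over all of its elements and then applies a
+# per-channel affine transform, unlike nn.LayerNorm's per-element affine.
+def _example_layer_norm_2d():
+    ln = LayerNorm2d(num_features=8)
+    x = torch.randn(4, 8, 16, 16)
+    out = ln(x)                        # same shape, per-sample normalized
+    return out.shape                   # torch.Size([4, 8, 16, 16])
+
+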
+class ScaleNorm(nn.Module):
+    r"""Scale normalization:
+    "Transformers without Tears: Improving the Normalization of Self-Attention"
+    Modified from:
+    https://github.com/tnq177/transformers_without_tears
+    """
+
+    def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
+        super().__init__()
+        # scale = num_features ** 0.5
+        if learned_scale:
+            self.scale = nn.Parameter(torch.tensor(1.))
+        else:
+            self.scale = 1.
+        # self.num_features = num_features
+        self.dim = dim
+        self.eps = eps
+        self.learned_scale = learned_scale
+
+    def forward(self, x):
+        # noinspection PyArgumentList
+        scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
+        return x * scale
+
+    def extra_repr(self):
+        s = 'learned_scale={learned_scale}'
+        return s.format(**self.__dict__)
+
+
+class PixelNorm(ScaleNorm):
+    def __init__(self, learned_scale=False, eps=1e-5, **_kwargs):
+        super().__init__(1, learned_scale, eps)
+
+
+class SplitMeanStd(nn.Module):
+    def __init__(self, num_features, eps=1e-5, **kwargs):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.multiple_outputs = True
+
+    def forward(self, x):
+        b, c, h, w = x.size()
+        mean = x.view(b, c, -1).mean(-1)[:, :, None, None]
+        var = x.view(b, c, -1).var(-1)[:, :, None, None]
+        std = torch.sqrt(var + self.eps)
+
+        # x = (x - mean) / std
+        return x, torch.cat((mean, std), dim=1)
+
+
+class PixelLayerNorm(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.norm = nn.LayerNorm(*args, **kwargs)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            b, c, h, w = x.shape
+            return self.norm(x.permute(0, 2, 3, 1).view(-1, c)).view(b, h, w, c).permute(0, 3, 1, 2)
+        else:
+            return self.norm(x)
+
+
+def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params):
+    r"""Return an activation normalization layer.
+
+    Args:
+        num_features (int): Number of feature channels.
+        norm_type (str):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        input_dim (int): Number of input dimensions.
+        norm_params: Arbitrary keyword arguments that will be used to
+            initialize the activation normalization.
+    """
+    input_dim = max(input_dim, 1)  # Norm1d works with both 0d and 1d inputs
+
+    if norm_type == 'none' or norm_type == '':
+        norm_layer = None
+    elif norm_type == 'batch':
+        norm = getattr(nn, 'BatchNorm%dd' % input_dim)
+        norm_layer = norm(num_features, **norm_params)
+    elif norm_type == 'instance':
+        affine = norm_params.pop('affine', True)  # Use affine=True by default
+        norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
+        norm_layer = norm(num_features, affine=affine, **norm_params)
+    elif norm_type == 'sync_batch':
+        norm_layer = SyncBatchNorm(num_features, **norm_params)
+    elif norm_type == 'layer':
+        norm_layer = nn.LayerNorm(num_features, **norm_params)
+    elif norm_type == 'layer_2d':
+        norm_layer = LayerNorm2d(num_features, **norm_params)
+    elif norm_type == 'pixel_layer':
+        elementwise_affine = norm_params.pop('affine', True)  # Use affine=True by default
+        norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params)
+    elif norm_type == 'scale':
+        norm_layer = ScaleNorm(**norm_params)
+    elif norm_type == 'pixel':
+        norm_layer = PixelNorm(**norm_params)
+        import imaginaire.config
+        if imaginaire.config.USE_JIT:
+            norm_layer = torch.jit.script(norm_layer)
+    elif norm_type == 'group':
+        num_groups = norm_params.pop('num_groups', 4)
+        norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params)
+    elif norm_type == 'adaptive':
+        norm_layer = AdaptiveNorm(num_features, **norm_params)
+    elif norm_type == 'dual_adaptive':
+        norm_layer = DualAdaptiveNorm(num_features, **norm_params)
+    elif norm_type == 'spatially_adaptive':
+        if input_dim != 2:
+            raise ValueError('Spatially adaptive normalization layers '
+                             'only support 2D input')
+        norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
+    elif norm_type == 'hyper_spatially_adaptive':
+        if input_dim != 2:
+            raise ValueError('Spatially adaptive normalization layers '
+                             'only support 2D input')
+        norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
+    else:
+        raise ValueError('Activation norm layer %s '
+                         'is not recognized' % norm_type)
+    return norm_layer
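+
+
+if __name__ == '__main__':
+    # Editorial smoke test, not part of the original imaginaire code: build a
+    # few activation-norm layers through the factory and run a forward pass.
+    # The chosen types avoid GPU-only layers such as 'sync_batch'.
+    feats = torch.randn(2, 32, 8, 8)
+    for _norm_type in ('instance', 'group', 'layer_2d'):
+        layer = get_activation_norm_layer(32, _norm_type, input_dim=2)
+        print(_norm_type, layer(feats).shape)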
diff --git a/imaginaire/layers/conv.py b/imaginaire/layers/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..499fc0442b77e3183225c3529a4e3590dab0bc57
--- /dev/null
+++ b/imaginaire/layers/conv.py
@@ -0,0 +1,1377 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .misc import ApplyNoise
+from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
+
+
+class _BaseConvBlock(nn.Module):
+    r"""An abstract wrapper class that wraps a torch convolution or linear layer
+    with normalization and nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
+                 inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale,
+                 init_gain):
+        super().__init__()
+        from .nonlinearity import get_nonlinearity_layer
+        from .weight_norm import get_weight_norm_layer
+        from .activation_norm import get_activation_norm_layer
+        self.weight_norm_type = weight_norm_type
+        self.stride = stride
+        self.clamp = clamp
+        self.init_gain = init_gain
+
+        # Nonlinearity layer.
+        if 'fused' in nonlinearity:
+            # Fusing nonlinearity with bias.
+            lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
+            conv_before_nonlinearity = order.find('C') < order.find('A')
+            if conv_before_nonlinearity:
+                assert bias is True
+                bias = False
+            channel = out_channels if conv_before_nonlinearity else in_channels
+            nonlinearity_layer = get_nonlinearity_layer(
+                nonlinearity, inplace=inplace_nonlinearity,
+                num_channels=channel, lr_mul=lr_mul)
+        else:
+            nonlinearity_layer = get_nonlinearity_layer(
+                nonlinearity, inplace=inplace_nonlinearity)
+
+        # Noise injection layer.
+        if apply_noise:
+            order = order.replace('C', 'CG')
+            noise_layer = ApplyNoise()
+        else:
+            noise_layer = None
+
+        # Convolutional layer.
+        if blur:
+            assert blur_kernel is not None
+            if stride == 2:
+                # Blur - Conv - Noise - Activate
+                p = (len(blur_kernel) - 2) + (kernel_size - 1)
+                pad0, pad1 = (p + 1) // 2, p // 2
+                padding = 0
+                blur_layer = Blur(
+                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+                )
+                order = order.replace('C', 'BC')
+            elif stride == 0.5:
+                # Conv - Blur - Noise - Activate
+                padding = 0
+                p = (len(blur_kernel) - 2) - (kernel_size - 1)
+                pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
+                blur_layer = Blur(
+                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+                )
+                order = order.replace('C', 'CB')
+            elif stride == 1:
+                # No blur for now
+                blur_layer = nn.Identity()
+            else:
+                raise NotImplementedError
+        else:
+            blur_layer = nn.Identity()
+
+        if weight_norm_params is None:
+            weight_norm_params = SimpleNamespace()
+        weight_norm = get_weight_norm_layer(
+            weight_norm_type, **vars(weight_norm_params))
+        conv_layer = weight_norm(self._get_conv_layer(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, input_dim))
+
+        # Normalization layer.
+        conv_before_norm = order.find('C') < order.find('N')
+        norm_channels = out_channels if conv_before_norm else in_channels
+        if activation_norm_params is None:
+            activation_norm_params = SimpleNamespace()
+        activation_norm_layer = get_activation_norm_layer(
+            norm_channels,
+            activation_norm_type,
+            input_dim,
+            **vars(activation_norm_params))
+
+        # Mapping from operation names to layers.
+        mappings = {'C': {'conv': conv_layer},
+                    'N': {'norm': activation_norm_layer},
+                    'A': {'nonlinearity': nonlinearity_layer}}
+        mappings.update({'B': {'blur': blur_layer}})
+        mappings.update({'G': {'noise': noise_layer}})
+
+        # All layers in order.
+        self.layers = nn.ModuleDict()
+        for op in order:
+            if list(mappings[op].values())[0] is not None:
+                self.layers.update(mappings[op])
+
+        # Whether this block expects conditional inputs.
+        self.conditional = \
+            getattr(conv_layer, 'conditional', False) or \
+            getattr(activation_norm_layer, 'conditional', False)
+
+        # Scale the output by a learnable scaler parameter.
+        if output_scale is not None:
+            self.output_scale = nn.Parameter(torch.tensor(output_scale))
+        else:
+            self.register_parameter("output_scale", None)
+
+    def forward(self, x, *cond_inputs, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        """
+        for key, layer in self.layers.items():
+            if getattr(layer, 'conditional', False):
+                # Layers that require conditional inputs.
+                x = layer(x, *cond_inputs, **kw_cond_inputs)
+            else:
+                x = layer(x)
+            if self.clamp is not None and isinstance(layer, nn.Conv2d):
+                x.clamp_(max=self.clamp)
+            if key == 'conv':
+                if self.output_scale is not None:
+                    x = x * self.output_scale
+        return x
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        # Returns the convolutional layer.
+        if input_dim == 0:
+            layer = nn.Linear(in_channels, out_channels, bias)
+        else:
+            if stride < 1:  # Fractionally-strided convolution.
+                padding_mode = 'zeros'
+                assert padding == 0
+                layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
+                stride = round(1 / stride)
+            else:
+                layer_type = getattr(nn, f'Conv{input_dim}d')
+            layer = layer_type(
+                in_channels, out_channels, kernel_size, stride, padding,
+                dilation=dilation, groups=groups, bias=bias,
+                padding_mode=padding_mode
+            )
+
+        return layer
+
+    def __repr__(self):
+        main_str = self._get_name() + '('
+        child_lines = []
+        for name, layer in self.layers.items():
+            mod_str = repr(layer)
+            if name == 'conv' and self.weight_norm_type != 'none' and \
+                    self.weight_norm_type != '':
+                mod_str = mod_str[:-1] + \
+                          ', weight_norm={}'.format(self.weight_norm_type) + ')'
+            if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
+                mod_str = mod_str[:-1] + \
+                          ', lr_mul={}'.format(layer.base_lr_mul) + ')'
+            mod_str = self._addindent(mod_str, 2)
+            child_lines.append(mod_str)
+        if len(child_lines) == 1:
+            main_str += child_lines[0]
+        else:
+            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
+
+        main_str += ')'
+        return main_str
+
+    @staticmethod
+    def _addindent(s_, numSpaces):
+        s = s_.split('\n')
+        # don't do anything for single-line stuff
+        if len(s) == 1:
+            return s_
+        first = s.pop(0)
+        s = [(numSpaces * ' ') + line for line in s]
+        s = '\n'.join(s)
+        s = first + '\n' + s
+        return s
+
+
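+# Editorial usage sketch, not part of the original imaginaire code: concrete
+# subclasses such as Conv2dBlock (defined later in this file) configure
+# _BaseConvBlock into a conv -> norm -> activation ('CNA') block. The argument
+# values below are illustrative assumptions.
+def _example_conv2d_block():
+    block = Conv2dBlock(3, 64, 3, padding=1, nonlinearity='relu')
+    x = torch.randn(2, 3, 32, 32)
+    return block(x).shape              # torch.Size([2, 64, 32, 32])
+
+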
+class ModulatedConv2dBlock(_BaseConvBlock):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=True, blur=True, order='CNA', demodulate=True,
+                 eps=True, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
+        self.eps = eps
+        self.demodulate = demodulate
+        assert style_dim is not None
+
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise, blur,
+                         order, 2, clamp, blur_kernel, output_scale, init_gain)
+        self.modulation = LinearBlock(style_dim, in_channels,
+                                      weight_norm_type=weight_norm_type,
+                                      weight_norm_params=weight_norm_params)
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        assert input_dim == 2
+        layer = ModulatedConv2d(
+            in_channels, out_channels, kernel_size, stride, padding,
+            dilation, groups, bias, padding_mode, self.demodulate, self.eps)
+        return layer
+
+    def forward(self, x, *cond_inputs, **kw_cond_inputs):
+        for layer in self.layers.values():
+            if getattr(layer, 'conditional', False):
+                # Layers that require conditional inputs.
+                assert len(cond_inputs) == 1
+                style = cond_inputs[0]
+                x = layer(
+                    x, self.modulation(style), **kw_cond_inputs
+                )
+            else:
+                x = layer(x)
+            if self.clamp is not None and isinstance(layer, ModulatedConv2d):
+                x.clamp_(max=self.clamp)
+        return x
+
+    def __repr__(self):
+        main_str = self._get_name() + '('
+        child_lines = []
+        for name, layer in self.layers.items():
+            mod_str = repr(layer)
+            if name == 'conv' and self.weight_norm_type != 'none' and \
+                    self.weight_norm_type != '':
+                mod_str = mod_str[:-1] + \
+                          ', weight_norm={}'.format(self.weight_norm_type) + \
+                          ', demodulate={}'.format(self.demodulate) + ')'
+            mod_str = self._addindent(mod_str, 2)
+            child_lines.append(mod_str)
+        child_lines.append(
+            self._addindent('Modulation(' + repr(self.modulation) + ')', 2)
+        )
+        if len(child_lines) == 1:
+            main_str += child_lines[0]
+        else:
+            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
+
+        main_str += ')'
+        return main_str
+
+
+class ModulatedConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
+                 dilation, groups, bias, padding_mode, demodulate=True,
+                 eps=1e-8):
+        # in_channels, out_channels, kernel_size, stride, padding,
+        # dilation, groups, bias, padding_mode
+        assert dilation == 1 and groups == 1
+
+        super().__init__()
+
+        self.eps = eps
+        self.kernel_size = kernel_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.padding = padding
+        self.stride = stride
+        self.padding_mode = padding_mode
+        # kernel_size // 2
+        # assert self.padding == padding
+
+        self.weight = nn.Parameter(
+            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
+        )
+
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            # noinspection PyTypeChecker
+            self.register_parameter('bias', None)
+
+        # self.modulation = LinearBlock(style_dim, in_channels,
+        #                               weight_norm_type=weight_norm_type)
+        self.demodulate = demodulate
+        self.conditional = True
+
+    def forward(self, x, style, **_kwargs):
+        batch, in_channel, height, width = x.shape
+
+        # style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+        # We assume the modulation layer is outside this module.
+        style = style.view(batch, 1, in_channel, 1, 1)
+        weight = self.weight.unsqueeze(0) * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(
+                weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(
+            batch * self.out_channels,
+            in_channel, self.kernel_size, self.kernel_size
+        )
+        if self.bias is not None:
+            bias = self.bias.repeat(batch)
+        else:
+            bias = self.bias
+
+        x = x.view(1, batch * in_channel, height, width)
+
+        if self.padding_mode != 'zeros':
+            x = F.pad(x, self._reversed_padding_repeated_twice,
+                      mode=self.padding_mode)
+            padding = (0, 0)
+        else:
+            padding = self.padding
+
+        if self.stride == 0.5:
+            weight = weight.view(
+                batch, self.out_channels, in_channel,
+                self.kernel_size, self.kernel_size
+            )
+            weight = weight.transpose(1, 2).reshape(
+                batch * in_channel, self.out_channels,
+                self.kernel_size, self.kernel_size
+            )
+            out = F.conv_transpose2d(
+                x, weight, bias, padding=padding, stride=2, groups=batch
+            )
+
+        elif self.stride == 2:
+            out = F.conv2d(
+                x, weight, bias, padding=padding, stride=2, groups=batch
+            )
+
+        else:
+            out = F.conv2d(x, weight, bias, padding=padding, groups=batch)
+
+        _, _, height, width = out.shape
+        out = out.view(batch, self.out_channels, height, width)
+
+        return out
+
+    def extra_repr(self):
+        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+             ', stride={stride}')
+        if self.bias is None:
+            s += ', bias=False'
+        if self.padding_mode != 'zeros':
+            s += ', padding_mode={padding_mode}'
+        return s.format(**self.__dict__)
+
+
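+# Editorial usage sketch, not part of the original imaginaire code: the raw
+# ModulatedConv2d expects an already-projected, per-sample style vector with
+# one entry per input channel (the LinearBlock projection lives in
+# ModulatedConv2dBlock). Shapes and arguments are illustrative assumptions.
+def _example_modulated_conv2d():
+    conv = ModulatedConv2d(in_channels=8, out_channels=16, kernel_size=3,
+                           stride=1, padding=1, dilation=1, groups=1,
+                           bias=True, padding_mode='zeros')
+    x = torch.randn(2, 8, 32, 32)
+    style = torch.randn(2, 8)          # per-sample channel-wise modulation
+    out = conv(x, style)               # weights modulated, then demodulated
+    return out.shape                   # torch.Size([2, 16, 32, 32])
+
+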
+class LinearBlock(_BaseConvBlock):
+    r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_features (int): Number of channels in the input tensor.
+        out_features (int): Number of channels in the output tensor.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, add
+            Gaussian noise with learnable magnitude after the
+            fully-connected layer.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: fully-connected,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_features, out_features, bias=True,
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
+                 init_gain=1.0, **_kwargs):
+        if bool(_kwargs):
+            warnings.warn(f"Unused keyword arguments {_kwargs}")
+        super().__init__(in_features, out_features, None, None,
+                         None, None, None, bias,
+                         None, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         False, order, 0, clamp, blur_kernel, output_scale,
+                         init_gain)
+
+
+class EmbeddingBlock(_BaseConvBlock):
+    def __init__(self, in_features, out_features, bias=True,
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, order='CNA', clamp=None, output_scale=None,
+                 init_gain=1.0, **_kwargs):
+        if bool(_kwargs):
+            warnings.warn(f"Unused keyword arguments {_kwargs}")
+        super().__init__(in_features, out_features, None, None,
+                         None, None, None, bias,
+                         None, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         False, order, 0, clamp, None, output_scale,
+                         init_gain)
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        assert input_dim == 0
+        return nn.Embedding(in_channels, out_channels)
+
+
+class Embedding2dBlock(_BaseConvBlock):
+    def __init__(self, in_features, out_features, bias=True,
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, order='CNA', clamp=None, output_scale=None,
+                 init_gain=1.0, **_kwargs):
+        if bool(_kwargs):
+            warnings.warn(f"Unused keyword arguments {_kwargs}")
+        super().__init__(in_features, out_features, None, None,
+                         None, None, None, bias,
+                         None, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         False, order, 0, clamp, None, output_scale,
+                         init_gain)
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        assert input_dim == 0
+        return Embedding2d(in_channels, out_channels)
+
+
+class Conv1dBlock(_BaseConvBlock):
+    r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         blur, order, 1, clamp, None, output_scale, init_gain)
+
+
+class Conv2dBlock(_BaseConvBlock):
+    r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1),
+                 output_scale=None, init_gain=1.0):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain)
+
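+# Usage sketch for Conv2dBlock (the channel counts and norm choices below are
+# illustrative assumptions, not prescribed by this repository):
+#
+#   conv = Conv2dBlock(3, 64, kernel_size=3, padding=1,
+#                      weight_norm_type='spectral',
+#                      activation_norm_type='instance',
+#                      nonlinearity='leakyrelu')
+#   y = conv(torch.randn(1, 3, 128, 128))   # expected shape: (1, 64, 128, 128)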
+
+class Conv3dBlock(_BaseConvBlock):
+    r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
+                 init_gain=1.0):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain)
+
+
+class _BaseHyperConvBlock(_BaseConvBlock):
+    r"""An abstract wrapper class that wraps a hyper convolutional layer
+    with normalization and nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, groups, bias,
+                 padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 nonlinearity, inplace_nonlinearity, apply_noise, blur,
+                 is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
+                 output_scale=None, init_gain=1.0):
+        self.is_hyper_conv = is_hyper_conv
+        if is_hyper_conv:
+            weight_norm_type = 'none'
+        if is_hyper_norm:
+            activation_norm_type = 'hyper_' + activation_norm_type
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise, blur,
+                         order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        if input_dim == 0:
+            raise ValueError('HyperLinearBlock is not supported.')
+        else:
+            name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv'
+            layer_type = eval(name + '%dd' % input_dim)
+            layer = layer_type(
+                in_channels, out_channels, kernel_size, stride, padding,
+                dilation, groups, bias, padding_mode)
+        return layer
+
+
+class HyperConv2dBlock(_BaseHyperConvBlock):
+    r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        is_hyper_conv (bool, optional, default=False): If ``True``, use
+            ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
+        is_hyper_norm (bool, optional, default=False): If ``True``, use
+            hyper normalizations.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 is_hyper_conv=False, is_hyper_norm=False,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, blur=False, order='CNA', clamp=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise, blur,
+                         is_hyper_conv, is_hyper_norm, order, 2, clamp)
+
+
+class HyperConv2d(nn.Module):
+    r"""Hyper Conv2d initialization.
+
+    Args:
+        in_channels (int): Dummy parameter.
+        out_channels (int): Dummy parameter.
+        kernel_size (int or tuple): Dummy parameter.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=1):
+            Zero-padding added to both sides of the input.
+        padding_mode (string, optional, default='zeros'):
+            ``'zeros'``, ``'reflect'``, ``'replicate'``
+            or ``'circular'``.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True): If ``True``,
+            adds a learnable bias to the output.
+    """
+
+    def __init__(self, in_channels=0, out_channels=0, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros'):
+        super().__init__()
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.use_bias = bias
+        self.padding_mode = padding_mode
+        self.conditional = True
+
+    def forward(self, x, *args, conv_weights=(None, None), **kwargs):
+        r"""Hyper Conv2d forward. Convolve x using the provided weight and bias.
+
+        Args:
+            x (N x C x H x W tensor): Input tensor.
+            conv_weights (N x C2 x C1 x k x k tensor or list of tensors):
+                Convolution weights or [weight, bias].
+        Returns:
+            y (N x C2 x H x W tensor): Output tensor.
+        """
+        if conv_weights is None:
+            conv_weight, conv_bias = None, None
+        elif isinstance(conv_weights, torch.Tensor):
+            conv_weight, conv_bias = conv_weights, None
+        else:
+            conv_weight, conv_bias = conv_weights
+
+        if conv_weight is None:
+            return x
+        if conv_bias is None:
+            if self.use_bias:
+                raise ValueError('bias not provided but set to true during '
+                                 'initialization')
+            conv_bias = [None] * x.size(0)
+        if self.padding_mode != 'zeros':
+            x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
+            padding = 0
+        else:
+            padding = self.padding
+
+        y = None
+        # noinspection PyArgumentList
+        for i in range(x.size(0)):
+            if self.stride >= 1:
+                yi = F.conv2d(x[i: i + 1],
+                              weight=conv_weight[i], bias=conv_bias[i],
+                              stride=self.stride, padding=padding,
+                              dilation=self.dilation, groups=self.groups)
+            else:
+                yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i],
+                                        bias=conv_bias[i], padding=self.padding,
+                                        stride=int(1 / self.stride),
+                                        dilation=self.dilation,
+                                        output_padding=self.padding,
+                                        groups=self.groups)
+            y = torch.cat([y, yi]) if y is not None else yi
+        return y
+
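+# Usage sketch for HyperConv2d: the convolution weights are supplied per batch
+# sample at call time (the shapes below are illustrative assumptions):
+#
+#   hyper_conv = HyperConv2d(kernel_size=3, stride=1, padding=1)
+#   x = torch.randn(4, 16, 32, 32)            # (N, C1, H, W)
+#   w = torch.randn(4, 32, 16, 3, 3)          # (N, C2, C1, k, k)
+#   b = torch.randn(4, 32)                    # (N, C2)
+#   y = hyper_conv(x, conv_weights=(w, b))    # expected shape: (4, 32, 32, 32)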
+
+class _BasePartialConvBlock(_BaseConvBlock):
+    r"""An abstract wrapper class that wraps a partial convolutional layer
+    with normalization and nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 nonlinearity, inplace_nonlinearity,
+                 multi_channel, return_mask,
+                 apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
+        self.multi_channel = multi_channel
+        self.return_mask = return_mask
+        self.partial_conv = True
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         False, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        if input_dim == 2:
+            layer_type = PartialConv2d
+        elif input_dim == 3:
+            layer_type = PartialConv3d
+        else:
+            raise ValueError('Partial conv only supports 2D and 3D conv now.')
+        layer = layer_type(
+            in_channels, out_channels, kernel_size, stride, padding,
+            dilation, groups, bias, padding_mode,
+            multi_channel=self.multi_channel, return_mask=self.return_mask)
+        return layer
+
+    def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            mask_in (tensor, optional, default=``None``): If not ``None``,
+                it masks the valid input region.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            (tuple):
+              - x (tensor): Output tensor.
+              - mask_out (tensor, optional): Masks the valid output region.
+        """
+        mask_out = None
+        for layer in self.layers.values():
+            if getattr(layer, 'conditional', False):
+                x = layer(x, *cond_inputs, **kw_cond_inputs)
+            elif getattr(layer, 'partial_conv', False):
+                x = layer(x, mask_in=mask_in, **kw_cond_inputs)
+                if isinstance(x, tuple):
+                    x, mask_out = x
+            else:
+                x = layer(x)
+
+        if mask_out is not None:
+            return x, mask_out
+        return x
+
+
+class PartialConv2dBlock(_BasePartialConvBlock):
+    r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+        multi_channel (bool, optional, default=False): If ``True``, use
+            different masks for different channels.
+        return_mask (bool, optional, default=True): If ``True``, the
+            forward call also returns a new mask.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 multi_channel=False, return_mask=True,
+                 apply_noise=False, order='CNA', clamp=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         multi_channel, return_mask, apply_noise, order, 2,
+                         clamp)
+
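+# Usage sketch for PartialConv2dBlock: the forward pass takes a validity mask
+# and (with return_mask=True) also returns the updated mask. The shapes below
+# are illustrative assumptions:
+#
+#   block = PartialConv2dBlock(3, 64, kernel_size=3, padding=1,
+#                              nonlinearity='relu')
+#   img = torch.randn(2, 3, 64, 64)
+#   mask = torch.ones(2, 1, 64, 64)           # 1 = valid pixel, 0 = hole
+#   out, new_mask = block(img, mask_in=mask)  # out: (2, 64, 64, 64)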
+
+class PartialConv3dBlock(_BasePartialConvBlock):
+    r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and
+    nonlinearity.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+        multi_channel (bool, optional, default=False): If ``True``, use
+            different masks for different channels.
+        return_mask (bool, optional, default=True): If ``True``, the
+            forward call also returns a new mask.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 multi_channel=False, return_mask=True,
+                 apply_noise=False, order='CNA', clamp=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         multi_channel, return_mask, apply_noise, order, 3,
+                         clamp)
+
+
+class _MultiOutBaseConvBlock(_BaseConvBlock):
+    r"""An abstract wrapper class that wraps a hyper convolutional layer with
+    normalization and nonlinearity. It can return multiple outputs, if some
+    layers in the block return more than one output.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
+                 inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
+                 output_scale=None, init_gain=1.0):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+        self.multiple_outputs = True
+
+    def forward(self, x, *cond_inputs, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            (tuple):
+              - x (tensor): Main output tensor.
+              - other_outputs (list of tensors): Other output tensors.
+        """
+        other_outputs = []
+        for layer in self.layers.values():
+            if getattr(layer, 'conditional', False):
+                x = layer(x, *cond_inputs, **kw_cond_inputs)
+            elif getattr(layer, 'multiple_outputs', False):
+                x, other_output = layer(x)
+                other_outputs.append(other_output)
+            else:
+                x = layer(x)
+        return (x, *other_outputs)
+
+
+class MultiOutConv2dBlock(_MultiOutBaseConvBlock):
+    r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
+    nonlinearity. It can return multiple outputs, if some layers in the block
+    return more than one output.
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        out_channels (int): Number of channels in the output tensor.
+        kernel_size (int or tuple): Size of the convolving kernel.
+        stride (int or float or tuple, optional, default=1):
+            Stride of the convolution.
+        padding (int or tuple, optional, default=0):
+            Zero-padding added to both sides of the input.
+        dilation (int or tuple, optional, default=1):
+            Spacing between kernel elements.
+        groups (int, optional, default=1): Number of blocked connections
+            from input channels to output channels.
+        bias (bool, optional, default=True):
+            If ``True``, adds a learnable bias to the output.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layer.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        order (str, optional, default='CNA'): Order of operations.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+            For example, a block initialized with ``order='CNA'`` will
+            do convolution first, then normalization, then nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 nonlinearity='none', inplace_nonlinearity=False,
+                 apply_noise=False, blur=False, order='CNA', clamp=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         nonlinearity, inplace_nonlinearity,
+                         apply_noise, blur, order, 2, clamp)
+
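+# Usage sketch for MultiOutConv2dBlock: the forward pass always returns a
+# tuple, with extra entries appended whenever an inner layer produces more
+# than one output (the settings below are illustrative assumptions):
+#
+#   block = MultiOutConv2dBlock(3, 64, kernel_size=3, padding=1,
+#                               nonlinearity='relu')
+#   y, = block(torch.randn(1, 3, 32, 32))   # 1-tuple when no extra outputs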
+
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu (guilinl@nvidia.com)
+###############################################################################
+class PartialConv2d(nn.Conv2d):
+    r"""Partial 2D convolution in
+    "Image inpainting for irregular holes using partial convolutions."
+    Liu et al., ECCV 2018
+    """
+
+    def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
+        # whether the mask is multi-channel or not
+        self.multi_channel = multi_channel
+        self.return_mask = return_mask
+        super(PartialConv2d, self).__init__(*args, **kwargs)
+
+        if self.multi_channel:
+            self.weight_maskUpdater = torch.ones(self.out_channels,
+                                                 self.in_channels,
+                                                 self.kernel_size[0],
+                                                 self.kernel_size[1])
+        else:
+            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
+                                                 self.kernel_size[1])
+
+        shape = self.weight_maskUpdater.shape
+        self.slide_winsize = shape[1] * shape[2] * shape[3]
+
+        self.last_size = (None, None, None, None)
+        self.update_mask = None
+        self.mask_ratio = None
+        self.partial_conv = True
+
+    def forward(self, x, mask_in=None):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            mask_in (tensor, optional, default=``None``): If not ``None``,
+                it masks the valid input region.
+        """
+        assert len(x.shape) == 4
+        if mask_in is not None or self.last_size != tuple(x.shape):
+            self.last_size = tuple(x.shape)
+
+            with torch.no_grad():
+                if self.weight_maskUpdater.type() != x.type():
+                    self.weight_maskUpdater = self.weight_maskUpdater.to(x)
+
+                if mask_in is None:
+                    # If mask is not provided, create a mask.
+                    if self.multi_channel:
+                        mask = torch.ones(x.data.shape[0],
+                                          x.data.shape[1],
+                                          x.data.shape[2],
+                                          x.data.shape[3]).to(x)
+                    else:
+                        mask = torch.ones(1, 1, x.data.shape[2],
+                                          x.data.shape[3]).to(x)
+                else:
+                    mask = mask_in
+
+                self.update_mask = F.conv2d(mask, self.weight_maskUpdater,
+                                            bias=None, stride=self.stride,
+                                            padding=self.padding,
+                                            dilation=self.dilation, groups=1)
+
+                # Use a larger eps (1e-6 instead of 1e-8) so the division
+                # below stays stable under mixed-precision training.
+                eps = 1e-6
+                self.mask_ratio = self.slide_winsize / (self.update_mask + eps)
+                self.update_mask = torch.clamp(self.update_mask, 0, 1)
+                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
+
+        raw_out = super(PartialConv2d, self).forward(
+            torch.mul(x, mask) if mask_in is not None else x)
+
+        if self.bias is not None:
+            bias_view = self.bias.view(1, self.out_channels, 1, 1)
+            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+            output = torch.mul(output, self.update_mask)
+        else:
+            output = torch.mul(raw_out, self.mask_ratio)
+
+        if self.return_mask:
+            return output, self.update_mask
+        else:
+            return output
+
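+# Usage sketch for the raw PartialConv2d layer (the mask convention and shapes
+# below are illustrative assumptions):
+#
+#   pconv = PartialConv2d(3, 64, kernel_size=3, padding=1,
+#                         multi_channel=False, return_mask=True)
+#   out, updated_mask = pconv(torch.randn(1, 3, 64, 64),
+#                             mask_in=torch.ones(1, 1, 64, 64))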
+
+class PartialConv3d(nn.Conv3d):
+    r"""Partial 3D convolution in
+    "Image inpainting for irregular holes using partial convolutions."
+    Liu et al., ECCV 2018
+    """
+
+    def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
+        # whether the mask is multi-channel or not
+        self.multi_channel = multi_channel
+        self.return_mask = return_mask
+        super(PartialConv3d, self).__init__(*args, **kwargs)
+
+        if self.multi_channel:
+            self.weight_maskUpdater = \
+                torch.ones(self.out_channels, self.in_channels,
+                           self.kernel_size[0], self.kernel_size[1],
+                           self.kernel_size[2])
+        else:
+            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
+                                                 self.kernel_size[1],
+                                                 self.kernel_size[2])
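+        # Note: the mask-update kernel is moved to CUDA once at construction,
+        # so this layer assumes it will run on a GPU.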
+        self.weight_maskUpdater = self.weight_maskUpdater.to('cuda')
+
+        shape = self.weight_maskUpdater.shape
+        self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]
+        self.partial_conv = True
+
+    def forward(self, x, mask_in=None):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            mask_in (tensor, optional, default=``None``): If not ``None``, it
+                masks the valid input region.
+        """
+        assert len(x.shape) == 5
+
+        with torch.no_grad():
+            mask = mask_in
+            update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None,
+                                   stride=self.stride, padding=self.padding,
+                                   dilation=self.dilation, groups=1)
+
+            mask_ratio = self.slide_winsize / (update_mask + 1e-8)
+            update_mask = torch.clamp(update_mask, 0, 1)
+            mask_ratio = torch.mul(mask_ratio, update_mask)
+
+        raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in))
+
+        if self.bias is not None:
+            bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
+            output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
+            if mask_in is not None:
+                output = torch.mul(output, update_mask)
+        else:
+            output = torch.mul(raw_out, mask_ratio)
+
+        if self.return_mask:
+            return output, update_mask
+        else:
+            return output
+
+
+class Embedding2d(nn.Embedding):
+    def __init__(self, in_channels, out_channels):
+        super().__init__(in_channels, out_channels)
+
+    def forward(self, x):
+        return F.embedding(
+            x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous()
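+
+# Usage sketch for Embedding2d: it maps an integer label map to per-pixel
+# embeddings (the class count and embedding width below are illustrative):
+#
+#   emb = Embedding2d(35, 64)                 # 35 classes -> 64-dim embeddings
+#   labels = torch.randint(0, 35, (2, 1, 32, 32))
+#   feats = emb(labels)                       # expected shape: (2, 64, 32, 32)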
diff --git a/imaginaire/layers/misc.py b/imaginaire/layers/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7731bd2fa855939a2866b211c33eb3ccce00c480
--- /dev/null
+++ b/imaginaire/layers/misc.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+
+
+class ApplyNoise(nn.Module):
+    r"""Add Gaussian noise to the input tensor."""
+
+    def __init__(self):
+        super().__init__()
+        # scale of the noise
+        self.scale = nn.Parameter(torch.zeros(1))
+        self.conditional = True
+
+    def forward(self, x, *_args, noise=None, **_kwargs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            noise (tensor, optional, default=``None``): Noise tensor to be
+                added to the input.
+        """
+        if noise is None:
+            sz = x.size()
+            noise = x.new_empty(sz[0], 1, *sz[2:]).normal_()
+
+        return x + self.scale * noise
+
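+# Usage sketch for ApplyNoise: the noise scale is a learnable parameter
+# initialized to zero, so the layer starts out as an identity mapping
+# (the tensor shapes below are illustrative):
+#
+#   noise_layer = ApplyNoise()
+#   y = noise_layer(torch.randn(2, 64, 16, 16))     # draws fresh N(0, 1) noise
+#   y = noise_layer(feat, noise=fixed_noise)        # or reuse a given tensor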
+
+class PartialSequential(nn.Sequential):
+    r"""Sequential block for partial convolutions."""
+    def __init__(self, *modules):
+        super(PartialSequential, self).__init__(*modules)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+        """
+        act = x[:, :-1]
+        mask = x[:, -1].unsqueeze(1)
+        for module in self:
+            act, mask = module(act, mask_in=mask)
+        return act
+
+
+class ConstantInput(nn.Module):
+    def __init__(self, channel, size=4):
+        super().__init__()
+        if isinstance(size, int):
+            h, w = size, size
+        else:
+            h, w = size
+        self.input = nn.Parameter(torch.randn(1, channel, h, w))
+
+    def forward(self):
+        return self.input
diff --git a/imaginaire/layers/non_local.py b/imaginaire/layers/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d1a8b36d668377ef4c9c4d897cd3b55fe0363e7
--- /dev/null
+++ b/imaginaire/layers/non_local.py
@@ -0,0 +1,88 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+
+
+class NonLocal2dBlock(nn.Module):
+    r"""Self attention Layer
+
+    Args:
+        in_channels (int): Number of channels in the input tensor.
+        scale (bool, optional, default=True): If ``True``, scale the
+            output by a learnable parameter.
+        clamp (bool, optional, default=``False``): If ``True``, clamp the
+            scaling parameter to (-1, 1).
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, weight_norm_params.__dict__ will be used as
+            keyword arguments when initializing weight normalization.
+        bias (bool, optional, default=True): If ``True``, adds bias in the
+            convolutional blocks.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 scale=True,
+                 clamp=False,
+                 weight_norm_type='none',
+                 weight_norm_params=None,
+                 bias=True):
+        super(NonLocal2dBlock, self).__init__()
+        self.clamp = clamp
+        self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
+        self.in_channels = in_channels
+        base_conv2d_block = partial(Conv2dBlock,
+                                    kernel_size=1,
+                                    stride=1,
+                                    padding=0,
+                                    weight_norm_type=weight_norm_type,
+                                    weight_norm_params=weight_norm_params,
+                                    bias=bias)
+        self.theta = base_conv2d_block(in_channels, in_channels // 8)
+        self.phi = base_conv2d_block(in_channels, in_channels // 8)
+        self.g = base_conv2d_block(in_channels, in_channels // 2)
+        self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
+        self.softmax = nn.Softmax(dim=-1)
+        self.max_pool = nn.MaxPool2d(2)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input feature maps of shape (B x C x H x W).
+        Returns:
+            out (tensor): Self-attention output added to the input feature,
+                with the same shape as ``x``.
+        """
+        n, c, h, w = x.size()
+        theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
+
+        phi = self.phi(x)
+        phi = self.max_pool(phi).view(n, -1, h * w // 4)
+
+        energy = torch.bmm(theta, phi)
+        attention = self.softmax(energy)
+
+        g = self.g(x)
+        g = self.max_pool(g).view(n, -1, h * w // 4)
+
+        out = torch.bmm(g, attention.permute(0, 2, 1))
+        out = out.view(n, c // 2, h, w)
+        out = self.out_conv(out)
+
+        if self.clamp:
+            out = self.gamma.clamp(-1, 1) * out + x
+        else:
+            out = self.gamma * out + x
+        return out
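+
+# Usage sketch for NonLocal2dBlock (the channel count is an illustrative
+# assumption; H and W should be even because of the internal 2x max-pooling):
+#
+#   attn = NonLocal2dBlock(256, scale=True, weight_norm_type='spectral')
+#   y = attn(torch.randn(1, 256, 32, 32))   # same shape as the input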
diff --git a/imaginaire/layers/nonlinearity.py b/imaginaire/layers/nonlinearity.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc172c74323e707e5a19f94e466c1bf0dae4418
--- /dev/null
+++ b/imaginaire/layers/nonlinearity.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity
+
+
+class ScaledLeakyReLU(nn.Module):
+    def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False):
+        super().__init__()
+
+        self.negative_slope = negative_slope
+        self.scale = scale
+        self.inplace = inplace
+
+    def forward(self, x):
+        return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale
+        # return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale)
+
+
+# @torch.jit.script
+# def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float):
+#     return F.leaky_relu(x, negative_slope, inplace=inplace) * scale
+
+
+def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs):
+    r"""Return a nonlinearity layer.
+
+    Args:
+        nonlinearity_type (str):
+            Type of nonlinear activation function.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'scaled_leakyrelu'``,
+            ``'prelu'``, ``'tanh'``, ``'sigmoid'``, ``'softmax'`` (optionally
+            ``'softmax,<dim>'``), or a ``'fused_*'`` variant.
+        inplace (bool): If ``True``, set ``inplace=True`` when initializing
+            the nonlinearity layer.
+    """
+    if nonlinearity_type.startswith('fused'):
+        nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs)
+    elif nonlinearity_type == 'relu':
+        nonlinearity = nn.ReLU(inplace=inplace)
+    elif nonlinearity_type == 'leakyrelu':
+        nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
+    elif nonlinearity_type == 'scaled_leakyrelu':
+        nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace)
+        import imaginaire.config
+        if imaginaire.config.USE_JIT:
+            nonlinearity = torch.jit.script(nonlinearity)
+    elif nonlinearity_type == 'prelu':
+        nonlinearity = nn.PReLU()
+    elif nonlinearity_type == 'tanh':
+        nonlinearity = nn.Tanh()
+    elif nonlinearity_type == 'sigmoid':
+        nonlinearity = nn.Sigmoid()
+    elif nonlinearity_type.startswith('softmax'):
+        dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
+        nonlinearity = nn.Softmax(dim=int(dim))
+    elif nonlinearity_type == 'none' or nonlinearity_type == '':
+        nonlinearity = None
+    else:
+        raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type)
+    return nonlinearity
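+
+# Usage sketch for get_nonlinearity_layer (each call below follows one of the
+# branches implemented above):
+#
+#   act = get_nonlinearity_layer('leakyrelu', inplace=False)  # nn.LeakyReLU(0.2)
+#   act = get_nonlinearity_layer('softmax,2', inplace=False)  # nn.Softmax(dim=2)
+#   act = get_nonlinearity_layer('none', inplace=False)       # None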
diff --git a/imaginaire/layers/residual.py b/imaginaire/layers/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e1bda4dd30922f694302803b7d606af7f3c0c21
--- /dev/null
+++ b/imaginaire/layers/residual.py
@@ -0,0 +1,1411 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import torch
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+from torch.utils.checkpoint import checkpoint
+
+from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
+                   LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
+                   PartialConv3dBlock, ModulatedConv2dBlock)
+from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample
+
+
+class _BaseResBlock(nn.Module):
+    r"""An abstract class for residual blocks.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity, apply_noise,
+                 hidden_channels_equal_out_channels,
+                 order, block, learn_shortcut, clamp, output_scale,
+                 skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.output_scale = output_scale
+        self.upsample_first = upsample_first
+        self.stride = stride
+        self.blur = blur
+        if skip_block is None:
+            skip_block = block
+
+        if order == 'pre_act':
+            order = 'NACNAC'
+        if isinstance(bias, bool):
+            # The bias for conv_block_0, conv_block_1, and conv_block_s.
+            biases = [bias, bias, bias]
+        elif isinstance(bias, list):
+            if len(bias) == 3:
+                biases = bias
+            else:
+                raise ValueError('The bias list must have exactly 3 elements.')
+        else:
+            raise ValueError('Bias must be either a bool or a list of 3 bools.')
+        if learn_shortcut is None:
+            self.learn_shortcut = (in_channels != out_channels)
+        else:
+            self.learn_shortcut = learn_shortcut
+        if len(order) > 6 or len(order) < 5:
+            raise ValueError('order must be either 5 or 6 characters')
+        if hidden_channels_equal_out_channels:
+            hidden_channels = out_channels
+        else:
+            hidden_channels = min(in_channels, out_channels)
+
+        # Parameters.
+        residual_params = {}
+        shortcut_params = {}
+        base_params = dict(dilation=dilation,
+                           groups=groups,
+                           padding_mode=padding_mode,
+                           clamp=clamp)
+        residual_params.update(base_params)
+        residual_params.update(
+            dict(activation_norm_type=activation_norm_type,
+                 activation_norm_params=activation_norm_params,
+                 weight_norm_type=weight_norm_type,
+                 weight_norm_params=weight_norm_params,
+                 padding=padding,
+                 apply_noise=apply_noise))
+        shortcut_params.update(base_params)
+        shortcut_params.update(dict(kernel_size=1))
+        if skip_activation_norm:
+            shortcut_params.update(
+                dict(activation_norm_type=activation_norm_type,
+                     activation_norm_params=activation_norm_params,
+                     apply_noise=False))
+        if skip_weight_norm:
+            shortcut_params.update(
+                dict(weight_norm_type=weight_norm_type,
+                     weight_norm_params=weight_norm_params))
+
+        # Residual branch.
+        if order.find('A') < order.find('C') and \
+                (activation_norm_type == '' or activation_norm_type == 'none'):
+            # Nonlinearity is the first operation in the residual path.
+            # In-place nonlinearity will modify the input variable and cause
+            # backward error.
+            first_inplace = False
+        else:
+            first_inplace = inplace_nonlinearity
+
+        (first_stride, second_stride, shortcut_stride,
+         first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
+        self.conv_block_0 = block(
+            in_channels, hidden_channels,
+            kernel_size=kernel_size,
+            bias=biases[0],
+            nonlinearity=nonlinearity,
+            order=order[0:3],
+            inplace_nonlinearity=first_inplace,
+            stride=first_stride,
+            blur=first_blur,
+            **residual_params
+        )
+        self.conv_block_1 = block(
+            hidden_channels, out_channels,
+            kernel_size=kernel_size,
+            bias=biases[1],
+            nonlinearity=nonlinearity,
+            order=order[3:],
+            inplace_nonlinearity=inplace_nonlinearity,
+            stride=second_stride,
+            blur=second_blur,
+            **residual_params
+        )
+
+        # Shortcut branch.
+        if self.learn_shortcut:
+            if skip_nonlinearity:
+                skip_nonlinearity_type = nonlinearity
+            else:
+                skip_nonlinearity_type = ''
+            self.conv_block_s = skip_block(in_channels, out_channels,
+                                           bias=biases[2],
+                                           nonlinearity=skip_nonlinearity_type,
+                                           order=order[0:3],
+                                           stride=shortcut_stride,
+                                           blur=shortcut_blur,
+                                           **shortcut_params)
+        elif in_channels < out_channels:
+            if skip_nonlinearity:
+                skip_nonlinearity_type = nonlinearity
+            else:
+                skip_nonlinearity_type = ''
+            self.conv_block_s = skip_block(in_channels,
+                                           out_channels - in_channels,
+                                           bias=biases[2],
+                                           nonlinearity=skip_nonlinearity_type,
+                                           order=order[0:3],
+                                           stride=shortcut_stride,
+                                           blur=shortcut_blur,
+                                           **shortcut_params)
+
+        # Whether this block expects conditional inputs.
+        self.conditional = \
+            getattr(self.conv_block_0, 'conditional', False) or \
+            getattr(self.conv_block_1, 'conditional', False)
+
+    def _get_stride_blur(self):
+        if self.stride > 1:
+            # Downsampling.
+            first_stride, second_stride = 1, self.stride
+            first_blur, second_blur = False, self.blur
+            shortcut_stride = self.stride
+            shortcut_blur = self.blur
+            self.upsample = None
+        elif self.stride < 1:
+            # Upsampling.
+            first_stride, second_stride = self.stride, 1
+            first_blur, second_blur = self.blur, False
+            shortcut_blur = False
+            shortcut_stride = 1
+            if self.blur:
+                # The shortcut branch uses blur_upsample + stride-1 conv
+                self.upsample = BlurUpsample()
+            else:
+                shortcut_stride = self.stride
+                self.upsample = nn.Upsample(scale_factor=2)
+        else:
+            first_stride = second_stride = 1
+            first_blur = second_blur = False
+            shortcut_stride = 1
+            shortcut_blur = False
+            self.upsample = None
+        return (first_stride, second_stride, shortcut_stride,
+                first_blur, second_blur, shortcut_blur)
+
+    def conv_blocks(
+            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
+    ):
+        r"""Returns the output of the residual branch.
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            dx (tensor): Output tensor.
+        """
+        if separate_cond:
+            dx = self.conv_block_0(x, cond_inputs[0],
+                                   **kw_cond_inputs.get('kwargs_0', {}))
+            dx = self.conv_block_1(dx, cond_inputs[1],
+                                   **kw_cond_inputs.get('kwargs_1', {}))
+        else:
+            dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
+            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
+        return dx
+
+    def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False,
+                **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            do_checkpoint (bool, optional, default=``False``): If ``True``,
+                trade compute for memory by checkpointing the model.
+            separate_cond (bool, optional, default=``False``): If ``True``,
+                use a separate conditional input for each of the two residual
+                conv blocks and the shortcut block.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            output (tensor): Output tensor.
+        """
+        if do_checkpoint:
+            dx = checkpoint(self.conv_blocks, x, *cond_inputs,
+                            separate_cond=separate_cond, **kw_cond_inputs)
+        else:
+            dx = self.conv_blocks(x, *cond_inputs,
+                                  separate_cond=separate_cond, **kw_cond_inputs)
+
+        if self.upsample_first and self.upsample is not None:
+            x = self.upsample(x)
+        if self.learn_shortcut:
+            if separate_cond:
+                x_shortcut = self.conv_block_s(
+                    x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
+                )
+            else:
+                x_shortcut = self.conv_block_s(
+                    x, *cond_inputs, **kw_cond_inputs
+                )
+        elif self.in_channels < self.out_channels:
+            if separate_cond:
+                x_shortcut_pad = self.conv_block_s(
+                    x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
+                )
+            else:
+                x_shortcut_pad = self.conv_block_s(
+                    x, *cond_inputs, **kw_cond_inputs
+                )
+            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
+        elif self.in_channels > self.out_channels:
+            x_shortcut = x[:, :self.out_channels, :, :]
+        else:
+            x_shortcut = x
+        if not self.upsample_first and self.upsample is not None:
+            x_shortcut = self.upsample(x_shortcut)
+
+        output = x_shortcut + dx
+        return self.output_scale * output
+
+    def extra_repr(self):
+        s = 'output_scale={output_scale}'
+        return s.format(**self.__dict__)
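+
+# Editorial note (illustrative, not upstream code): with the default
+# order='CNACNA', a _BaseResBlock subclass roughly computes
+#
+#     dx = conv_block_1(conv_block_0(x))          # residual branch
+#     out = output_scale * (shortcut(x) + dx)
+#
+# where the shortcut is a 1x1 conv block when learn_shortcut is enabled,
+# channel padding via concatenation when in_channels < out_channels, channel
+# slicing when in_channels > out_channels, and identity otherwise; for
+# stride < 1 the shortcut input is upsampled first.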
+
+
+class ModulatedRes2dBlock(_BaseResBlock):
+    def __init__(self, in_channels, out_channels, style_dim, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=True, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1,
+                 demodulate=True, eps=1e-8):
+        block = functools.partial(ModulatedConv2dBlock,
+                                  style_dim=style_dim,
+                                  demodulate=demodulate, eps=eps)
+        skip_block = Conv2dBlock
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale, skip_block=skip_block)
+
+    def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs):
+        assert len(list(cond_inputs)) == 2
+        dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs)
+        dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs)
+        return dx
+
+
+class ResLinearBlock(_BaseResBlock):
+    r"""Residual block with full-connected layers.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, add
+            Gaussian noise with learnable magnitude after the
+            fully-connected layer.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: fully-connected,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, bias=True,
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, None, 1, None, None,
+                         None, bias, None, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, LinearBlock,
+                         learn_shortcut, clamp, output_scale)
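+
+# Hedged usage sketch (editorial): assuming LinearBlock from conv.py follows
+# the usual (N, C) convention,
+#
+#     block = ResLinearBlock(256, 512)
+#     y = block(torch.randn(8, 256))   # expected shape: (8, 512)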
+
+
+class Res1dBlock(_BaseResBlock):
+    r"""Residual block for 1D input.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, Conv1dBlock,
+                         learn_shortcut, clamp, output_scale)
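+
+# Hedged usage sketch (editorial): with the default kernel_size=3, padding=1
+# and stride=1 the temporal length is preserved,
+#
+#     block = Res1dBlock(32, 64)
+#     y = block(torch.randn(4, 32, 100))   # expected shape: (4, 64, 100)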
+
+
+class Res2dBlock(_BaseResBlock):
+    r"""Residual block for 2D input.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 skip_weight_norm=True,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1, blur=False, upsample_first=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, Conv2dBlock,
+                         learn_shortcut, clamp, output_scale, blur=blur,
+                         upsample_first=upsample_first,
+                         skip_weight_norm=skip_weight_norm)
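+
+# Hedged usage sketch (editorial): with the defaults the spatial size is
+# preserved and a 1x1 shortcut is learned automatically when the channel
+# counts differ; stride > 1 downsamples and stride < 1 upsamples (optionally
+# with blur), following _get_stride_blur above.
+#
+#     block = Res2dBlock(64, 128)
+#     y = block(torch.randn(1, 64, 32, 32))   # expected shape: (1, 128, 32, 32)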
+
+
+class Res3dBlock(_BaseResBlock):
+    r"""Residual block for 3D input.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, Conv3dBlock,
+                         learn_shortcut, clamp, output_scale)
+
+
+class _BaseHyperResBlock(_BaseResBlock):
+    r"""An abstract class for hyper residual blocks.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity, apply_noise,
+                 hidden_channels_equal_out_channels,
+                 order, is_hyper_conv, is_hyper_norm, block, learn_shortcut,
+                 clamp=None, output_scale=1):
+        block = functools.partial(block,
+                                  is_hyper_conv=is_hyper_conv,
+                                  is_hyper_norm=is_hyper_norm)
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale)
+
+    def forward(self, x, *cond_inputs, conv_weights=(None,) * 3,
+                norm_weights=(None,) * 3, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            conv_weights (list of tensors): Convolution weights for
+                three convolutional layers respectively.
+            norm_weights (list of tensors): Normalization weights for
+                three convolutional layers respectively.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            output (tensor): Output tensor.
+        """
+        dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0],
+                               norm_weights=norm_weights[0])
+        dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1],
+                               norm_weights=norm_weights[1])
+        if self.learn_shortcut:
+            x_shortcut = self.conv_block_s(x, *cond_inputs,
+                                           conv_weights=conv_weights[2],
+                                           norm_weights=norm_weights[2])
+        else:
+            x_shortcut = x
+        output = x_shortcut + dx
+        return self.output_scale * output
+
+
+class HyperRes2dBlock(_BaseHyperResBlock):
+    r"""Hyper residual block for 2D input.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        is_hyper_conv (bool, optional, default=False): If ``True``, use
+            ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
+        is_hyper_norm (bool, optional, default=False): If ``True``, use
+            hyper normalizations.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='', weight_norm_params=None,
+                 activation_norm_type='', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', is_hyper_conv=False, is_hyper_norm=False,
+                 learn_shortcut=None, clamp=None, output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels,
+                         order, is_hyper_conv, is_hyper_norm,
+                         HyperConv2dBlock, learn_shortcut, clamp, output_scale)
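+
+# Editorial note (illustrative, placeholder names): unlike the plain residual
+# blocks, the forward pass of a hyper residual block can consume externally
+# predicted weights,
+#
+#     y = hyper_block(x, conv_weights=(w0, w1, ws), norm_weights=(n0, n1, ns))
+#
+# where the three entries feed conv_block_0, conv_block_1 and the learned
+# shortcut respectively; None entries presumably fall back to the blocks' own
+# parameters (see HyperConv2dBlock in conv.py for the expected weight shapes).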
+
+
+class _BaseDownResBlock(_BaseResBlock):
+    r"""An abstract class for residual blocks with downsampling.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity,
+                 apply_noise, hidden_channels_equal_out_channels,
+                 order, block, pooling, down_factor, learn_shortcut,
+                 clamp=None, output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale)
+        self.pooling = pooling(down_factor)
+
+    def forward(self, x, *cond_inputs):
+        r"""
+
+        Args:
+            x (tensor) : Input tensor.
+            cond_inputs (list of tensors) : conditional input.
+        Returns:
+            output (tensor) : Output tensor.
+        """
+        dx = self.conv_block_0(x, *cond_inputs)
+        dx = self.conv_block_1(dx, *cond_inputs)
+        dx = self.pooling(dx)
+        if self.learn_shortcut:
+            x_shortcut = self.conv_block_s(x, *cond_inputs)
+        else:
+            x_shortcut = x
+        x_shortcut = self.pooling(x_shortcut)
+        output = x_shortcut + dx
+        return self.output_scale * output
+
+
+class DownRes2dBlock(_BaseDownResBlock):
+    r"""Residual block for 2D input with downsampling.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        pooling (class, optional, default=nn.AvgPool2d): PyTorch pooling
+            layer to be used.
+        down_factor (int, optional, default=2): Downsampling factor.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', pooling=nn.AvgPool2d, down_factor=2,
+                 learn_shortcut=None, clamp=None, output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity,
+                         nonlinearity, inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels,
+                         order, Conv2dBlock, pooling,
+                         down_factor, learn_shortcut, clamp, output_scale)
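+
+# Hedged usage sketch (editorial): the residual and shortcut branches are both
+# pooled by down_factor, so with the defaults the spatial size is halved,
+#
+#     block = DownRes2dBlock(64, 128)
+#     y = block(torch.randn(1, 64, 32, 32))   # expected shape: (1, 128, 16, 16)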
+
+
+class _BaseUpResBlock(_BaseResBlock):
+    r"""An abstract class for residual blocks with upsampling.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity,
+                 apply_noise, hidden_channels_equal_out_channels,
+                 order, block, upsample, up_factor, learn_shortcut, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale)
+        self.order = order
+        self.upsample = upsample(scale_factor=up_factor)
+
+    def _get_stride_blur(self):
+        # Upsampling.
+        first_stride, second_stride = self.stride, 1
+        first_blur, second_blur = self.blur, False
+        shortcut_blur = False
+        shortcut_stride = 1
+        # if self.upsample == 'blur_deconv':
+
+        if self.blur:
+            # The shortcut branch uses blur_upsample + stride-1 conv
+            self.upsample = BlurUpsample()
+        else:
+            shortcut_stride = self.stride
+            self.upsample = nn.Upsample(scale_factor=2)
+
+        return (first_stride, second_stride, shortcut_stride,
+                first_blur, second_blur, shortcut_blur)
+
+    def forward(self, x, *cond_inputs):
+        r"""Implementation of the up residual block forward function.
+        If the order is 'NAC' for the first residual block, we will first
+        do the activation norm and nonlinearity, in the original resolution.
+        We will then upsample the activation map to a higher resolution. We
+        then do the convolution.
+        It is is other orders, then we first do the whole processing and
+        then upsample.
+
+        Args:
+            x (tensor) : Input tensor.
+            cond_inputs (list of tensors) : Conditional input.
+        Returns:
+            output (tensor) : Output tensor.
+        """
+        # In this particular upsample residual block operation, we first
+        # upsample the skip connection.
+        if self.learn_shortcut:
+            x_shortcut = self.upsample(x)
+            x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
+        else:
+            x_shortcut = self.upsample(x)
+
+        if self.order[0:3] == 'NAC':
+            for ix, layer in enumerate(self.conv_block_0.layers.values()):
+                if getattr(layer, 'conditional', False):
+                    x = layer(x, *cond_inputs)
+                else:
+                    x = layer(x)
+                if ix == 1:
+                    x = self.upsample(x)
+        else:
+            x = self.conv_block_0(x, *cond_inputs)
+            x = self.upsample(x)
+        x = self.conv_block_1(x, *cond_inputs)
+
+        output = x_shortcut + x
+        return self.output_scale * output
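+
+# Editorial sketch of the two residual paths above, depending on order[0:3]:
+#
+#     # 'NAC...': norm -> act at the low resolution, upsample, then conv
+#     x = act(norm(x)); x = upsample(x); x = conv(x)
+#
+#     # other orders: run conv_block_0 at the low resolution, then upsample
+#     x = conv_block_0(x); x = upsample(x)
+#
+# In both cases the skip connection is upsampled first and, if learn_shortcut
+# is enabled, passed through conv_block_s.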
+
+
+class UpRes2dBlock(_BaseUpResBlock):
+    r"""Residual block for 2D input with downsampling.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        upsample (class, optional, default=NearestUpsample): PyTorch
+            upsampling layer to be used.
+        up_factor (int, optional, default=2): Upsampling factor.
+        learn_shortcut (bool, optional, default=None): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', upsample=NearestUpsample, up_factor=2,
+                 learn_shortcut=None, clamp=None, output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity,
+                         nonlinearity, inplace_nonlinearity,
+                         apply_noise, hidden_channels_equal_out_channels,
+                         order, Conv2dBlock,
+                         upsample, up_factor, learn_shortcut, clamp,
+                         output_scale)
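+
+# Hedged usage sketch (editorial): with the default NearestUpsample and
+# up_factor=2 the spatial size doubles,
+#
+#     block = UpRes2dBlock(64, 32)
+#     y = block(torch.randn(1, 64, 16, 16))   # expected shape: (1, 32, 32, 32)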
+
+
+class _BasePartialResBlock(_BaseResBlock):
+    r"""An abstract class for residual blocks with partial convolution.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity,
+                 multi_channel, return_mask,
+                 apply_noise, hidden_channels_equal_out_channels,
+                 order, block, learn_shortcut, clamp=None, output_scale=1):
+        block = functools.partial(block,
+                                  multi_channel=multi_channel,
+                                  return_mask=return_mask)
+        self.partial_conv = True
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale)
+
+    def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            mask_in (tensor, optional, default=``None``): If not ``None``,
+                it masks the valid input region.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        Returns:
+            (tuple):
+              - output (tensor): Output tensor.
+              - mask_out (tensor, optional): Masks the valid output region.
+        """
+        if self.conv_block_0.layers.conv.return_mask:
+            dx, mask_out = self.conv_block_0(x, *cond_inputs,
+                                             mask_in=mask_in, **kw_cond_inputs)
+            dx, mask_out = self.conv_block_1(dx, *cond_inputs,
+                                             mask_in=mask_out, **kw_cond_inputs)
+        else:
+            dx = self.conv_block_0(x, *cond_inputs,
+                                   mask_in=mask_in, **kw_cond_inputs)
+            dx = self.conv_block_1(dx, *cond_inputs,
+                                   mask_in=mask_in, **kw_cond_inputs)
+            mask_out = None
+
+        if self.learn_shortcut:
+            x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs,
+                                           **kw_cond_inputs)
+            if type(x_shortcut) == tuple:
+                x_shortcut, _ = x_shortcut
+        else:
+            x_shortcut = x
+        output = x_shortcut + dx
+
+        if mask_out is not None:
+            return output, mask_out
+        return self.output_scale * output
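+
+# Hedged usage note (editorial, placeholder names): when the underlying partial
+# convolutions are built with return_mask=True, the block returns both the
+# features and the updated validity mask,
+#
+#     y, mask_out = partial_block(x, mask_in=mask)
+#
+# otherwise only the (scaled) output tensor is returned.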
+
+
+class PartialRes2dBlock(_BasePartialResBlock):
+    r"""Residual block for 2D input with partial convolution.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=False): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 multi_channel=False, return_mask=True,
+                 apply_noise=False,
+                 hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias,
+                         padding_mode, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, multi_channel, return_mask,
+                         apply_noise, hidden_channels_equal_out_channels,
+                         order, PartialConv2dBlock, learn_shortcut, clamp,
+                         output_scale)
+
+
+class PartialRes3dBlock(_BasePartialResBlock):
+    r"""Residual block for 3D input with partial convolution.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=False): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 multi_channel=False, return_mask=True,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride, padding, dilation, groups, bias,
+                         padding_mode, weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity,
+                         nonlinearity, inplace_nonlinearity, multi_channel,
+                         return_mask, apply_noise,
+                         hidden_channels_equal_out_channels,
+                         order, PartialConv3dBlock, learn_shortcut, clamp,
+                         output_scale)
+
+
+class _BaseMultiOutResBlock(_BaseResBlock):
+    r"""An abstract class for residual blocks that can returns multiple outputs.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity,
+                 apply_noise, hidden_channels_equal_out_channels,
+                 order, block, learn_shortcut, clamp=None, output_scale=1,
+                 blur=False, upsample_first=True):
+        self.multiple_outputs = True
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, block,
+                         learn_shortcut, clamp, output_scale, blur=blur,
+                         upsample_first=upsample_first)
+
+    def forward(self, x, *cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+        Returns:
+            (tuple):
+              - output (tensor): Output tensor.
+              - aux_outputs_0 (tensor): Auxiliary output of the first block.
+              - aux_outputs_1 (tensor): Auxiliary output of the second block.
+        """
+        dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs)
+        dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs)
+        if self.learn_shortcut:
+            # We are not using the auxiliary outputs of self.conv_block_s.
+            x_shortcut, _ = self.conv_block_s(x, *cond_inputs)
+        else:
+            x_shortcut = x
+        output = x_shortcut + dx
+        return self.output_scale * output, aux_outputs_0, aux_outputs_1
+
+
+class MultiOutRes2dBlock(_BaseMultiOutResBlock):
+    r"""Residual block for 2D input. It can return multiple outputs, if some
+    layers in the block return more than one output.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=False): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=None, clamp=None,
+                 output_scale=1, blur=False, upsample_first=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order,
+                         MultiOutConv2dBlock, learn_shortcut, clamp,
+                         output_scale, blur=blur, upsample_first=upsample_first)
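For orientation, a minimal usage sketch of `PartialRes2dBlock`, assuming the `imaginaire` package and its partial-convolution ops import cleanly; the tensor shapes and the random validity mask are illustrative only:

```python
import torch
from imaginaire.layers.residual import PartialRes2dBlock

# Identity shortcut (in_channels == out_channels); the defaults keep
# multi_channel=False and return_mask=True, so a single-channel validity
# mask is updated and returned alongside the features.
block = PartialRes2dBlock(in_channels=32, out_channels=32)
x = torch.randn(2, 32, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.3).float()  # 1 = valid pixel
out, mask_out = block(x, mask_in=mask)
print(out.shape, mask_out.shape)
```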
diff --git a/imaginaire/layers/residual_deep.py b/imaginaire/layers/residual_deep.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0bbcd497f4689bed4faf20e8e47c0fc4e282812
--- /dev/null
+++ b/imaginaire/layers/residual_deep.py
@@ -0,0 +1,346 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
+from .conv import Conv2dBlock
+
+
+class _BaseDeepResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride, padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 skip_activation_norm, skip_nonlinearity,
+                 nonlinearity, inplace_nonlinearity, apply_noise,
+                 hidden_channels_equal_out_channels,
+                 order, block, learn_shortcut, output_scale, skip_block=None,
+                 blur=True, border_free=True, resample_first=True,
+                 skip_weight_norm=True, hidden_channel_ratio=4):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.output_scale = output_scale
+        self.resample_first = resample_first
+        self.stride = stride
+        self.blur = blur
+        self.border_free = border_free
+        assert not border_free
+        if skip_block is None:
+            skip_block = block
+
+        if order == 'pre_act':
+            order = 'NACNAC'
+        if isinstance(bias, bool):
+            # The bias for conv_block_0, conv_block_1, and conv_block_s.
+            biases = [bias, bias, bias]
+        elif isinstance(bias, list):
+            if len(bias) == 3:
+                biases = bias
+            else:
+                raise ValueError('Bias list must have exactly 3 elements.')
+        else:
+            raise ValueError('Bias must be either a bool or a list.')
+        self.learn_shortcut = learn_shortcut
+        if len(order) > 6 or len(order) < 5:
+            raise ValueError('order must be either 5 or 6 characters')
+        hidden_channels = in_channels // hidden_channel_ratio
+
+        # Parameters.
+        residual_params = {}
+        shortcut_params = {}
+        base_params = dict(dilation=dilation,
+                           groups=groups,
+                           padding_mode=padding_mode)
+        residual_params.update(base_params)
+        residual_params.update(
+            dict(activation_norm_type=activation_norm_type,
+                 activation_norm_params=activation_norm_params,
+                 weight_norm_type=weight_norm_type,
+                 weight_norm_params=weight_norm_params,
+                 apply_noise=apply_noise)
+        )
+        shortcut_params.update(base_params)
+        shortcut_params.update(dict(kernel_size=1))
+        if skip_activation_norm:
+            shortcut_params.update(
+                dict(activation_norm_type=activation_norm_type,
+                     activation_norm_params=activation_norm_params,
+                     apply_noise=False))
+        if skip_weight_norm:
+            shortcut_params.update(
+                dict(weight_norm_type=weight_norm_type,
+                     weight_norm_params=weight_norm_params))
+
+        # Residual branch.
+        if order.find('A') < order.find('C') and \
+                (activation_norm_type == '' or activation_norm_type == 'none'):
+            # Nonlinearity is the first operation in the residual path.
+            # In-place nonlinearity will modify the input variable and cause
+            # backward error.
+            first_inplace = False
+        else:
+            first_inplace = inplace_nonlinearity
+
+        (first_stride, second_stride, shortcut_stride,
+         first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
+
+        self.conv_block_1x1_in = block(
+            in_channels, hidden_channels,
+            1, 1, 0,
+            bias=biases[0],
+            nonlinearity=nonlinearity,
+            order=order[0:3],
+            inplace_nonlinearity=first_inplace,
+            **residual_params
+        )
+
+        self.conv_block_0 = block(
+            hidden_channels, hidden_channels,
+            kernel_size=2 if self.border_free and first_stride < 1 else
+            kernel_size,
+            padding=padding,
+            bias=biases[0],
+            nonlinearity=nonlinearity,
+            order=order[0:3],
+            inplace_nonlinearity=inplace_nonlinearity,
+            stride=first_stride,
+            blur=first_blur,
+            **residual_params
+        )
+        self.conv_block_1 = block(
+            hidden_channels, hidden_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=biases[1],
+            nonlinearity=nonlinearity,
+            order=order[3:],
+            inplace_nonlinearity=inplace_nonlinearity,
+            stride=second_stride,
+            blur=second_blur,
+            **residual_params
+        )
+
+        self.conv_block_1x1_out = block(
+            hidden_channels, out_channels,
+            1, 1, 0,
+            bias=biases[1],
+            nonlinearity=nonlinearity,
+            order=order[0:3],
+            inplace_nonlinearity=inplace_nonlinearity,
+            **residual_params
+        )
+
+        # Shortcut branch.
+        if self.learn_shortcut:
+            if skip_nonlinearity:
+                skip_nonlinearity_type = nonlinearity
+            else:
+                skip_nonlinearity_type = ''
+            self.conv_block_s = skip_block(in_channels, out_channels,
+                                           bias=biases[2],
+                                           nonlinearity=skip_nonlinearity_type,
+                                           order=order[0:3],
+                                           stride=shortcut_stride,
+                                           blur=shortcut_blur,
+                                           **shortcut_params)
+        elif in_channels < out_channels:
+            if skip_nonlinearity:
+                skip_nonlinearity_type = nonlinearity
+            else:
+                skip_nonlinearity_type = ''
+            self.conv_block_s = skip_block(in_channels,
+                                           out_channels - in_channels,
+                                           bias=biases[2],
+                                           nonlinearity=skip_nonlinearity_type,
+                                           order=order[0:3],
+                                           stride=shortcut_stride,
+                                           blur=shortcut_blur,
+                                           **shortcut_params)
+
+        # Whether this block expects conditional inputs.
+        self.conditional = \
+            getattr(self.conv_block_0, 'conditional', False) or \
+            getattr(self.conv_block_1, 'conditional', False) or \
+            getattr(self.conv_block_1x1_in, 'conditional', False) or \
+            getattr(self.conv_block_1x1_out, 'conditional', False)
+
+    def _get_stride_blur(self):
+        if self.stride > 1:
+            # Downsampling.
+            first_stride, second_stride = 1, self.stride
+            first_blur, second_blur = False, self.blur
+            shortcut_blur = False
+            shortcut_stride = 1
+            if self.blur:
+                # The shortcut branch uses blur_downsample + stride-1 conv
+                if self.border_free:
+                    self.resample = nn.AvgPool2d(2)
+                else:
+                    self.resample = BlurDownsample()
+            else:
+                shortcut_stride = self.stride
+                self.resample = nn.AvgPool2d(2)
+        elif self.stride < 1:
+            # Upsampling.
+            first_stride, second_stride = self.stride, 1
+            first_blur, second_blur = self.blur, False
+            shortcut_blur = False
+            shortcut_stride = 1
+            if self.blur:
+                # The shortcut branch uses blur_upsample + stride-1 conv
+                if self.border_free:
+                    self.resample = nn.Upsample(scale_factor=2,
+                                                mode='bilinear')
+                else:
+                    self.resample = BlurUpsample()
+            else:
+                shortcut_stride = self.stride
+                self.resample = nn.Upsample(scale_factor=2)
+        else:
+            first_stride = second_stride = 1
+            first_blur = second_blur = False
+            shortcut_stride = 1
+            shortcut_blur = False
+            self.resample = None
+        return (first_stride, second_stride, shortcut_stride,
+                first_blur, second_blur, shortcut_blur)
+
+    def conv_blocks(
+            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
+    ):
+        if separate_cond:
+            assert len(list(cond_inputs)) == 4
+            dx = self.conv_block_1x1_in(x, cond_inputs[0],
+                                        **kw_cond_inputs.get('kwargs_0', {}))
+            dx = self.conv_block_0(dx, cond_inputs[1],
+                                   **kw_cond_inputs.get('kwargs_1', {}))
+            dx = self.conv_block_1(dx, cond_inputs[2],
+                                   **kw_cond_inputs.get('kwargs_2', {}))
+            dx = self.conv_block_1x1_out(dx, cond_inputs[3],
+                                         **kw_cond_inputs.get('kwargs_3', {}))
+        else:
+            dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
+            dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
+            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
+            dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
+        return dx
+
+    def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
+        if do_checkpoint:
+            dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
+        else:
+            dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)
+
+        if self.resample_first and self.resample is not None:
+            x = self.resample(x)
+        if self.learn_shortcut:
+            x_shortcut = self.conv_block_s(
+                x, *cond_inputs, **kw_cond_inputs
+            )
+        elif self.in_channels < self.out_channels:
+            x_shortcut_pad = self.conv_block_s(
+                x, *cond_inputs, **kw_cond_inputs
+            )
+            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
+        elif self.in_channels > self.out_channels:
+            x_shortcut = x[:, :self.out_channels, :, :]
+        else:
+            x_shortcut = x
+        if not self.resample_first and self.resample is not None:
+            x_shortcut = self.resample(x_shortcut)
+
+        output = x_shortcut + dx
+        return self.output_scale * output
+
+    def extra_repr(self):
+        s = 'output_scale={output_scale}'
+        return s.format(**self.__dict__)
+
+
+class DeepRes2dBlock(_BaseDeepResBlock):
+    r"""Residual block for 2D input.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=False): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        learn_shortcut (bool, optional, default=False): If ``True``, always use
+            a convolutional shortcut instead of an identity one, otherwise only
+            use a convolutional one if input and output have different number of
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3,
+                 stride=1, padding=1, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 weight_norm_type='none', weight_norm_params=None,
+                 activation_norm_type='none', activation_norm_params=None,
+                 skip_activation_norm=True, skip_nonlinearity=False,
+                 skip_weight_norm=True,
+                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
+                 apply_noise=False, hidden_channels_equal_out_channels=False,
+                 order='CNACNA', learn_shortcut=False, output_scale=1,
+                 blur=True, resample_first=True, border_free=False):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode,
+                         weight_norm_type, weight_norm_params,
+                         activation_norm_type, activation_norm_params,
+                         skip_activation_norm, skip_nonlinearity, nonlinearity,
+                         inplace_nonlinearity, apply_noise,
+                         hidden_channels_equal_out_channels, order, Conv2dBlock,
+                         learn_shortcut, output_scale, blur=blur,
+                         resample_first=resample_first, border_free=border_free,
+                         skip_weight_norm=skip_weight_norm)
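A small sketch of the bottleneck block defined above, assuming `Conv2dBlock` and the rest of `imaginaire.layers` are importable; the channel sizes are illustrative:

```python
import torch
from imaginaire.layers.residual_deep import DeepRes2dBlock

# stride=1, so no resampling; the residual path is a 1x1 -> 3x3 -> 3x3 -> 1x1
# bottleneck with hidden width in_channels // 4, and the shortcut zero-cost
# path pads the missing 128 - 64 channels with a 1x1 convolution before
# concatenation.
block = DeepRes2dBlock(in_channels=64, out_channels=128,
                       activation_norm_type='instance',
                       nonlinearity='leakyrelu')
x = torch.randn(2, 64, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 32, 32])
```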
diff --git a/imaginaire/layers/vit.py b/imaginaire/layers/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd0d039d715444efc3c5a1e9889330bfa5b4c4f
--- /dev/null
+++ b/imaginaire/layers/vit.py
@@ -0,0 +1,204 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+from .misc import ApplyNoise
+from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
+
+
+class ViT2dBlock(nn.Module):
+    r"""An abstract wrapper class that wraps a torch convolution or linear layer
+    with normalization and nonlinearity.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, groups, bias, padding_mode,
+                 weight_norm_type, weight_norm_params,
+                 activation_norm_type, activation_norm_params,
+                 nonlinearity, inplace_nonlinearity,
+                 apply_noise, blur, order, input_dim, clamp,
+                 blur_kernel=(1, 3, 3, 1), output_scale=None,
+                 init_gain=1.0):
+        super().__init__()
+        from .nonlinearity import get_nonlinearity_layer
+        from .weight_norm import get_weight_norm_layer
+        from .activation_norm import get_activation_norm_layer
+        self.weight_norm_type = weight_norm_type
+        self.stride = stride
+        self.clamp = clamp
+        self.init_gain = init_gain
+
+        # Nonlinearity layer.
+        if 'fused' in nonlinearity:
+            # Fusing nonlinearity with bias.
+            lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
+            conv_before_nonlinearity = order.find('C') < order.find('A')
+            if conv_before_nonlinearity:
+                assert bias
+                bias = False
+            channel = out_channels if conv_before_nonlinearity else in_channels
+            nonlinearity_layer = get_nonlinearity_layer(
+                nonlinearity, inplace=inplace_nonlinearity,
+                num_channels=channel, lr_mul=lr_mul)
+        else:
+            nonlinearity_layer = get_nonlinearity_layer(
+                nonlinearity, inplace=inplace_nonlinearity)
+
+        # Noise injection layer.
+        if apply_noise:
+            order = order.replace('C', 'CG')
+            noise_layer = ApplyNoise()
+        else:
+            noise_layer = None
+
+        # Convolutional layer.
+        if blur:
+            if stride == 2:
+                # Blur - Conv - Noise - Activate
+                p = (len(blur_kernel) - 2) + (kernel_size - 1)
+                pad0, pad1 = (p + 1) // 2, p // 2
+                padding = 0
+                blur_layer = Blur(
+                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+                )
+                order = order.replace('C', 'BC')
+            elif stride == 0.5:
+                # Conv - Blur - Noise - Activate
+                padding = 0
+                p = (len(blur_kernel) - 2) - (kernel_size - 1)
+                pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
+                blur_layer = Blur(
+                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+                )
+                order = order.replace('C', 'CB')
+            elif stride == 1:
+                # No blur for now
+                blur_layer = nn.Identity()
+            else:
+                raise NotImplementedError
+        else:
+            blur_layer = nn.Identity()
+
+        if weight_norm_params is None:
+            weight_norm_params = SimpleNamespace()
+        weight_norm = get_weight_norm_layer(
+            weight_norm_type, **vars(weight_norm_params))
+        conv_layer = weight_norm(self._get_conv_layer(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, input_dim))
+
+        # Normalization layer.
+        conv_before_norm = order.find('C') < order.find('N')
+        norm_channels = out_channels if conv_before_norm else in_channels
+        if activation_norm_params is None:
+            activation_norm_params = SimpleNamespace()
+        activation_norm_layer = get_activation_norm_layer(
+            norm_channels,
+            activation_norm_type,
+            input_dim,
+            **vars(activation_norm_params))
+
+        # Mapping from operation names to layers.
+        mappings = {'C': {'conv': conv_layer},
+                    'N': {'norm': activation_norm_layer},
+                    'A': {'nonlinearity': nonlinearity_layer}}
+        mappings.update({'B': {'blur': blur_layer}})
+        mappings.update({'G': {'noise': noise_layer}})
+
+        # All layers in order.
+        self.layers = nn.ModuleDict()
+        for op in order:
+            if list(mappings[op].values())[0] is not None:
+                self.layers.update(mappings[op])
+
+        # Whether this block expects conditional inputs.
+        self.conditional = \
+            getattr(conv_layer, 'conditional', False) or \
+            getattr(activation_norm_layer, 'conditional', False)
+
+        if output_scale is not None:
+            self.output_scale = nn.Parameter(torch.tensor(output_scale))
+        else:
+            self.register_parameter("output_scale", None)
+
+    def forward(self, x, *cond_inputs, **kw_cond_inputs):
+        r"""
+
+        Args:
+            x (tensor): Input tensor.
+            cond_inputs (list of tensors) : Conditional input tensors.
+            kw_cond_inputs (dict) : Keyword conditional inputs.
+        """
+        for key, layer in self.layers.items():
+            if getattr(layer, 'conditional', False):
+                # Layers that require conditional inputs.
+                x = layer(x, *cond_inputs, **kw_cond_inputs)
+            else:
+                x = layer(x)
+            if self.clamp is not None and isinstance(layer, nn.Conv2d):
+                x.clamp_(max=self.clamp)
+            if key == 'conv':
+                if self.output_scale is not None:
+                    x = x * self.output_scale
+        return x
+
+    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+                        padding, dilation, groups, bias, padding_mode,
+                        input_dim):
+        # Returns the convolutional layer.
+        if input_dim == 0:
+            layer = nn.Linear(in_channels, out_channels, bias)
+        else:
+            if stride < 1:  # Fractionally-strided convolution.
+                padding_mode = 'zeros'
+                assert padding == 0
+                layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
+                stride = round(1 / stride)
+            else:
+                layer_type = getattr(nn, f'Conv{input_dim}d')
+            layer = layer_type(
+                in_channels, out_channels, kernel_size, stride, padding,
+                dilation=dilation, groups=groups, bias=bias,
+                padding_mode=padding_mode
+            )
+
+        return layer
+
+    def __repr__(self):
+        main_str = self._get_name() + '('
+        child_lines = []
+        for name, layer in self.layers.items():
+            mod_str = repr(layer)
+            if name == 'conv' and self.weight_norm_type != 'none' and \
+                    self.weight_norm_type != '':
+                mod_str = mod_str[:-1] + \
+                          ', weight_norm={}'.format(self.weight_norm_type) + ')'
+            if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
+                mod_str = mod_str[:-1] + \
+                          ', lr_mul={}'.format(layer.base_lr_mul) + ')'
+            mod_str = self._addindent(mod_str, 2)
+            child_lines.append(mod_str)
+        if len(child_lines) == 1:
+            main_str += child_lines[0]
+        else:
+            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
+
+        main_str += ')'
+        return main_str
+
+    @staticmethod
+    def _addindent(s_, numSpaces):
+        s = s_.split('\n')
+        # don't do anything for single-line stuff
+        if len(s) == 1:
+            return s_
+        first = s.pop(0)
+        s = [(numSpaces * ' ') + line for line in s]
+        s = '\n'.join(s)
+        s = first + '\n' + s
+        return s
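`ViT2dBlock` takes its configuration positionally; a hedged construction sketch with normalization disabled (the argument values are illustrative and follow the `__init__` signature above):

```python
import torch
from imaginaire.layers.vit import ViT2dBlock

# Order 'CNA' = convolution, normalization (disabled here), activation.
block = ViT2dBlock(
    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
    dilation=1, groups=1, bias=True, padding_mode='zeros',
    weight_norm_type='none', weight_norm_params=None,
    activation_norm_type='none', activation_norm_params=None,
    nonlinearity='leakyrelu', inplace_nonlinearity=False,
    apply_noise=False, blur=False, order='CNA', input_dim=2, clamp=None)
x = torch.randn(1, 64, 32, 32)
y = block(x)  # plain Conv2d followed by LeakyReLU
```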
diff --git a/imaginaire/layers/weight_norm.py b/imaginaire/layers/weight_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15ca2d21cea70062fa24ffdcd5adab51c8dcb25
--- /dev/null
+++ b/imaginaire/layers/weight_norm.py
@@ -0,0 +1,267 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import collections
+import functools
+
+import torch
+from torch import nn
+from torch.nn.utils import spectral_norm, weight_norm
+from torch.nn.utils.spectral_norm import SpectralNorm, \
+    SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook
+
+from .conv import LinearBlock
+
+
+class WeightDemodulation(nn.Module):
+    r"""Weight demodulation in
+    "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.
+
+    Args:
+        conv (torch.nn.Modules): Convolutional layer.
+        cond_dims (int): The number of channels in the conditional input.
+        eps (float, optional, default=1e-8): a value added to the
+            denominator for numerical stability.
+        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
+            predicts bias from the conditional input.
+        demod (bool, optional, default=False): If ``True``, performs
+            weight demodulation.
+    """
+
+    def __init__(self, conv, cond_dims, eps=1e-8,
+                 adaptive_bias=False, demod=True):
+        super().__init__()
+        self.conv = conv
+        self.adaptive_bias = adaptive_bias
+        if adaptive_bias:
+            self.conv.register_parameter('bias', None)
+            self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
+        self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
+        self.eps = eps
+        self.demod = demod
+        self.conditional = True
+
+    def forward(self, x, y, **_kwargs):
+        r"""Weight demodulation forward"""
+        b, c, h, w = x.size()
+        self.conv.groups = b
+        gamma = self.fc_gamma(y)
+        gamma = gamma[:, None, :, None, None]
+        weight = self.conv.weight[None, :, :, :, :] * gamma
+
+        if self.demod:
+            d = torch.rsqrt(
+                (weight ** 2).sum(
+                    dim=(2, 3, 4), keepdim=True) + self.eps)
+            weight = weight * d
+
+        x = x.reshape(1, -1, h, w)
+        _, _, *ws = weight.shape
+        weight = weight.reshape(b * self.conv.out_channels, *ws)
+        x = self.conv._conv_forward(x, weight)
+
+        x = x.reshape(-1, self.conv.out_channels, h, w)
+        if self.adaptive_bias:
+            x += self.fc_beta(y)[:, :, None, None]
+        return x
+
+
+def weight_demod(
+        conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True):
+    r"""Weight demodulation."""
+    return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod)
+
+
+class ScaledLR(object):
+    def __init__(self, weight_name, bias_name):
+        self.weight_name = weight_name
+        self.bias_name = bias_name
+
+    def compute_weight(self, module):
+        weight = getattr(module, self.weight_name + '_ori')
+        return weight * module.weight_scale
+
+    def compute_bias(self, module):
+        bias = getattr(module, self.bias_name + '_ori')
+        if bias is not None:
+            return bias * module.bias_scale
+        else:
+            return None
+
+    @staticmethod
+    def apply(module, weight_name, bias_name, lr_mul, equalized):
+        assert weight_name == 'weight'
+        assert bias_name == 'bias'
+        fn = ScaledLR(weight_name, bias_name)
+        module.register_forward_pre_hook(fn)
+
+        if hasattr(module, bias_name):
+            # module.bias is a parameter (can be None).
+            bias = getattr(module, bias_name)
+            delattr(module, bias_name)
+            module.register_parameter(bias_name + '_ori', bias)
+        else:
+            # module.bias does not exist.
+            bias = None
+            setattr(module, bias_name + '_ori', bias)
+        if bias is not None:
+            setattr(module, bias_name, bias.data)
+        else:
+            setattr(module, bias_name, None)
+        module.register_buffer('bias_scale', torch.tensor(lr_mul))
+
+        if hasattr(module, weight_name + '_orig'):
+            # The module has been wrapped with spectral normalization.
+            # We only want to keep a single weight parameter.
+            weight = getattr(module, weight_name + '_orig')
+            delattr(module, weight_name + '_orig')
+            module.register_parameter(weight_name + '_ori', weight)
+            setattr(module, weight_name + '_orig', weight.data)
+            # Put this hook before the spectral norm hook.
+            module._forward_pre_hooks = collections.OrderedDict(
+                reversed(list(module._forward_pre_hooks.items()))
+            )
+            module.use_sn = True
+        else:
+            weight = getattr(module, weight_name)
+            delattr(module, weight_name)
+            module.register_parameter(weight_name + '_ori', weight)
+            setattr(module, weight_name, weight.data)
+            module.use_sn = False
+
+        # assert weight.dim() == 4 or weight.dim() == 2
+        if equalized:
+            fan_in = weight.data.size(1) * weight.data[0][0].numel()
+            # Theoretically, the gain should be sqrt(2) instead of 1.
+            # The official StyleGAN2 uses 1 for some reason.
+            module.register_buffer(
+                'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
+            )
+        else:
+            module.register_buffer('weight_scale', torch.tensor(lr_mul))
+
+        module.lr_mul = module.weight_scale
+        module.base_lr_mul = lr_mul
+
+        return fn
+
+    def remove(self, module):
+        with torch.no_grad():
+            weight = self.compute_weight(module)
+        delattr(module, self.weight_name + '_ori')
+
+        if module.use_sn:
+            setattr(module, self.weight_name + '_orig', weight.detach())
+        else:
+            delattr(module, self.weight_name)
+            module.register_parameter(self.weight_name,
+                                      torch.nn.Parameter(weight.detach()))
+
+        with torch.no_grad():
+            bias = self.compute_bias(module)
+        delattr(module, self.bias_name)
+        delattr(module, self.bias_name + '_ori')
+        if bias is not None:
+            module.register_parameter(self.bias_name,
+                                      torch.nn.Parameter(bias.detach()))
+        else:
+            module.register_parameter(self.bias_name, None)
+
+        module.lr_mul = 1.0
+        module.base_lr_mul = 1.0
+
+    def __call__(self, module, input):
+        weight = self.compute_weight(module)
+        if module.use_sn:
+            # The following spectral norm hook will compute the SN of
+            # "module.weight_orig" and store the normalized weight in
+            # "module.weight".
+            setattr(module, self.weight_name + '_orig', weight)
+        else:
+            setattr(module, self.weight_name, weight)
+        bias = self.compute_bias(module)
+        setattr(module, self.bias_name, bias)
+
+
+def remove_weight_norms(module, weight_name='weight', bias_name='bias'):
+    if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'):
+        for k in list(module._forward_pre_hooks.keys()):
+            hook = module._forward_pre_hooks[k]
+            if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)):
+                hook.remove(module)
+                del module._forward_pre_hooks[k]
+
+        for k, hook in module._state_dict_hooks.items():
+            if isinstance(hook, SpectralNormStateDictHook) and \
+                    hook.fn.name == weight_name:
+                del module._state_dict_hooks[k]
+                break
+
+        for k, hook in module._load_state_dict_pre_hooks.items():
+            if isinstance(hook, SpectralNormLoadStateDictPreHook) and \
+                    hook.fn.name == weight_name:
+                del module._load_state_dict_pre_hooks[k]
+                break
+
+    return module
+
+
+def remove_equalized_lr(module, weight_name='weight', bias_name='bias'):
+    for k, hook in module._forward_pre_hooks.items():
+        if isinstance(hook, ScaledLR) and hook.weight_name == weight_name:
+            hook.remove(module)
+            del module._forward_pre_hooks[k]
+            break
+    else:
+        raise ValueError("Equalized learning rate not found")
+
+    return module
+
+
+def scaled_lr(
+        module, weight_name='weight', bias_name='bias', lr_mul=1.,
+        equalized=False,
+):
+    ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized)
+    return module
+
+
+def get_weight_norm_layer(norm_type, **norm_params):
+    r"""Return weight normalization.
+
+    Args:
+        norm_type (str):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        norm_params: Arbitrary keyword arguments that will be used to
+            initialize the weight normalization.
+    """
+    if norm_type == 'none' or norm_type == '':  # no normalization
+        return lambda x: x
+    elif norm_type == 'spectral':  # spectral normalization
+        return functools.partial(spectral_norm, **norm_params)
+    elif norm_type == 'weight':  # weight normalization
+        return functools.partial(weight_norm, **norm_params)
+    elif norm_type == 'weight_demod':  # weight demodulation
+        return functools.partial(weight_demod, **norm_params)
+    elif norm_type == 'equalized_lr':  # equalized learning rate
+        return functools.partial(scaled_lr, equalized=True, **norm_params)
+    elif norm_type == 'scaled_lr':  # equalized learning rate
+        return functools.partial(scaled_lr, **norm_params)
+    elif norm_type == 'equalized_lr_spectral':
+        lr_mul = norm_params.pop('lr_mul', 1.0)
+        return lambda x: functools.partial(
+            scaled_lr, equalized=True, lr_mul=lr_mul)(
+            functools.partial(spectral_norm, **norm_params)(x)
+        )
+    elif norm_type == 'scaled_lr_spectral':
+        lr_mul = norm_params.pop('lr_mul', 1.0)
+        return lambda x: functools.partial(
+            scaled_lr, lr_mul=lr_mul)(
+            functools.partial(spectral_norm, **norm_params)(x)
+        )
+    else:
+        raise ValueError(
+            'Weight norm layer %s is not recognized' % norm_type)
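A brief sketch of the weight-normalization factory applied to plain `torch.nn` layers, assuming the package imports cleanly; the `lr_mul` value is illustrative:

```python
import torch
from torch import nn
from imaginaire.layers.weight_norm import get_weight_norm_layer, scaled_lr

# Spectral normalization through the factory interface.
sn = get_weight_norm_layer('spectral')
conv = sn(nn.Conv2d(32, 64, 3, padding=1))
feat = conv(torch.randn(1, 32, 16, 16))

# StyleGAN2-style equalized learning rate on a linear layer: a forward
# pre-hook rescales 'weight_ori' and 'bias_ori' before every call.
fc = scaled_lr(nn.Linear(128, 128), lr_mul=0.01, equalized=True)
y = fc(torch.randn(4, 128))
```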
diff --git a/imaginaire/losses/TVloss.py b/imaginaire/losses/TVloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..158b12519e90c580df1c71658c86b47ce8e71a6e
--- /dev/null
+++ b/imaginaire/losses/TVloss.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+
+
+class TV_loss(nn.Module):
+    r"""Total variation regularizer over the three trailing dimensions of a
+    4-D tensor (e.g. a density volume of shape ``(B, D1, D2, D3)``)."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        B, D1, D2, D3 = input.size()
+        tv_d1 = torch.pow(input[:, 1:, :, :] - input[:, :-1, :, :], 2).sum()
+        tv_d2 = torch.pow(input[:, :, 1:, :] - input[:, :, :-1, :], 2).sum()
+        tv_d3 = torch.pow(input[:, :, :, 1:] - input[:, :, :, :-1], 2).sum()
+        return (tv_d1 + tv_d2 + tv_d3) / (B * D1 * D2 * D3)
+
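A usage sketch for the total-variation term above; the volume shape is illustrative (in this project it would typically be a predicted density volume):

```python
import torch
from imaginaire.losses.TVloss import TV_loss

tv = TV_loss()
density = torch.rand(2, 64, 128, 128, requires_grad=True)  # (B, D1, D2, D3)
loss = tv(density)   # squared differences along the three trailing axes,
loss.backward()      # normalized by the total number of elements
```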
diff --git a/imaginaire/losses/__init__.py b/imaginaire/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b7d8e381cd8fe0771dd2ce2478af0986a9be87
--- /dev/null
+++ b/imaginaire/losses/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .gan import GANLoss
+from .perceptual import PerceptualLoss
+from .feature_matching import FeatureMatchingLoss
+from .kl import GaussianKLLoss
+from .flow import MaskedL1Loss, FlowLoss
+from .dict import DictLoss
+from .weighted_mse import WeightedMSELoss
+from .TVloss import TV_loss
+
+__all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss',
+           'MaskedL1Loss', 'FlowLoss', 'DictLoss',
+           'WeightedMSELoss', 'TV_loss']
+
+try:
+    from .gradient_penalty import GradientPenaltyLoss
+    __all__.extend(['GradientPenaltyLoss'])
+except:  # noqa
+    pass
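With these exports in place, the losses can be pulled from the package namespace. A short `FeatureMatchingLoss` sketch, where the nested lists mimic multi-scale discriminator features and all shapes are illustrative:

```python
import torch
from imaginaire.losses import FeatureMatchingLoss

fm = FeatureMatchingLoss(criterion='l1')
# Two discriminators, two intermediate feature maps each.
real_feats = [[torch.randn(1, 8, 32, 32), torch.randn(1, 16, 16, 16)]
              for _ in range(2)]
fake_feats = [[torch.randn(1, 8, 32, 32), torch.randn(1, 16, 16, 16)]
              for _ in range(2)]
loss = fm(fake_feats, real_feats)
```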
diff --git a/imaginaire/losses/dict.py b/imaginaire/losses/dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0e5ee9516013036b92c564fa349755fc9fc1c9
--- /dev/null
+++ b/imaginaire/losses/dict.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+
+class DictLoss(nn.Module):
+    def __init__(self, criterion='l1'):
+        super(DictLoss, self).__init__()
+        if criterion == 'l1':
+            self.criterion = nn.L1Loss()
+        elif criterion == 'l2' or criterion == 'mse':
+            self.criterion = nn.MSELoss()
+        else:
+            raise ValueError('Criterion %s is not recognized' % criterion)
+
+    def forward(self, fake, real):
+        """Return the target vector for the l1/l2 loss computation.
+
+        Args:
+           fake (dict, list or tuple): Discriminator features of fake images.
+           real (dict, list or tuple): Discriminator features of real images.
+        Returns:
+           loss (tensor): Loss value.
+        """
+        loss = 0
+        if type(fake) == dict:
+            for key in fake.keys():
+                loss += self.criterion(fake[key], real[key].detach())
+        elif type(fake) == list or type(fake) == tuple:
+            for f, r in zip(fake, real):
+                loss += self.criterion(f, r.detach())
+        else:
+            loss += self.criterion(fake, real.detach())
+        return loss
diff --git a/imaginaire/losses/feature_matching.py b/imaginaire/losses/feature_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70034b3c3afba5a914261b55cf0abeab832391c
--- /dev/null
+++ b/imaginaire/losses/feature_matching.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+
+class FeatureMatchingLoss(nn.Module):
+    r"""Compute feature matching loss"""
+    def __init__(self, criterion='l1'):
+        super(FeatureMatchingLoss, self).__init__()
+        if criterion == 'l1':
+            self.criterion = nn.L1Loss()
+        elif criterion == 'l2' or criterion == 'mse':
+            self.criterion = nn.MSELoss()
+        else:
+            raise ValueError('Criterion %s is not recognized' % criterion)
+
+    def forward(self, fake_features, real_features):
+        r"""Return the target vector for the binary cross entropy loss
+        computation.
+
+        Args:
+           fake_features (list of lists): Discriminator features of fake images.
+           real_features (list of lists): Discriminator features of real images.
+
+        Returns:
+           (tensor): Loss value.
+        """
+        num_d = len(fake_features)
+        dis_weight = 1.0 / num_d
+        loss = fake_features[0][0].new_tensor(0)
+        for i in range(num_d):
+            for j in range(len(fake_features[i])):
+                tmp_loss = self.criterion(fake_features[i][j],
+                                          real_features[i][j].detach())
+                loss += dis_weight * tmp_loss
+        return loss
diff --git a/imaginaire/losses/flow.py b/imaginaire/losses/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..3464d949bba8cf148f7bc5c8fa092995f66c663a
--- /dev/null
+++ b/imaginaire/losses/flow.py
@@ -0,0 +1,313 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+import importlib
+import warnings
+
+import torch
+import torch.nn as nn
+
+from imaginaire.model_utils.fs_vid2vid import (get_face_mask, get_fg_mask,
+                                               get_part_mask, pick_image,
+                                               resample)
+
+
+class MaskedL1Loss(nn.Module):
+    r"""Masked L1 loss constructor."""
+
+    def __init__(self, normalize_over_valid=False):
+        super(MaskedL1Loss, self).__init__()
+        self.criterion = nn.L1Loss()
+        self.normalize_over_valid = normalize_over_valid
+
+    def forward(self, input, target, mask):
+        r"""Masked L1 loss computation.
+
+        Args:
+            input (tensor): Input tensor.
+            target (tensor): Target tensor.
+            mask (tensor): Mask to be applied to the output loss.
+
+        Returns:
+            (tensor): Loss value.
+        """
+        mask = mask.expand_as(input)
+        loss = self.criterion(input * mask, target * mask)
+        if self.normalize_over_valid:
+            # The loss has been averaged over all pixels.
+            # Only average over regions which are valid.
+            loss = loss * torch.numel(mask) / (torch.sum(mask) + 1e-6)
+        return loss
+
+
+class FlowLoss(nn.Module):
+    r"""Flow loss constructor.
+
+    Args:
+        cfg (obj): Configuration.
+    """
+
+    def __init__(self, cfg):
+        super(FlowLoss, self).__init__()
+        self.cfg = cfg
+        self.data_cfg = cfg.data
+        self.criterion = nn.L1Loss()
+        self.criterionMasked = MaskedL1Loss()
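+        # The flow estimation network is loaded dynamically from the module
+        # path given in cfg.flow_network.type; that module must expose a
+        # FlowNet class.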
+        flow_module = importlib.import_module(cfg.flow_network.type)
+        self.flowNet = flow_module.FlowNet(pretrained=True)
+        self.warp_ref = getattr(cfg.gen.flow, 'warp_ref', False)
+        self.pose_cfg = pose_cfg = getattr(cfg.data, 'for_pose_dataset', None)
+        self.for_pose_dataset = pose_cfg is not None
+        self.has_fg = getattr(cfg.data, 'has_foreground', False)
+
+    def forward(self, data, net_G_output, current_epoch):
+        r"""Compute losses on the output flow and occlusion mask.
+
+        Args:
+            data (dict): Input data.
+            net_G_output (dict): Generator output.
+            current_epoch (int): Current training epoch number.
+        Returns:
+            (dict):
+              - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+              - loss_flow_warp (tensor): L1 loss between the warped image and the
+                target image when using the flow to warp.
+              - loss_mask (tensor): Loss for the occlusion mask.
+        """
+        tgt_label, tgt_image = data['label'], data['image']
+
+        fake_image = net_G_output['fake_images']
+        warped_images = net_G_output['warped_images']
+        flow = net_G_output['fake_flow_maps']
+        occ_mask = net_G_output['fake_occlusion_masks']
+
+        if self.warp_ref:
+            # Pick the most similar reference image to warp.
+            ref_labels, ref_images = data['ref_labels'], data['ref_images']
+            ref_idx = net_G_output['ref_idx']
+            ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
+        else:
+            ref_label = ref_image = None
+
+        # Compute the ground truth flows and confidence maps.
+        flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            if self.warp_ref:
+                # Compute GT for warping reference -> target.
+                if self.for_pose_dataset:
+                    # Use DensePose maps to compute flows for pose dataset.
+                    flow_gt_ref, conf_gt_ref = self.flowNet(tgt_label[:, :3],
+                                                            ref_label[:, :3])
+                else:
+                    # Use RGB images for other datasets.
+                    flow_gt_ref, conf_gt_ref = self.flowNet(tgt_image,
+                                                            ref_image)
+
+            if current_epoch >= self.cfg.single_frame_epoch and \
+                    data['real_prev_image'] is not None:
+                # Compute GT for warping previous -> target.
+                tgt_image_prev = data['real_prev_image']
+                flow_gt_prev, conf_gt_prev = self.flowNet(tgt_image,
+                                                          tgt_image_prev)
+
+        flow_gt = [flow_gt_ref, flow_gt_prev]
+        flow_conf_gt = [conf_gt_ref, conf_gt_prev]
+        # Get the foreground masks.
+        fg_mask, ref_fg_mask = get_fg_mask([tgt_label, ref_label], self.has_fg)
+
+        # Compute losses for flow maps and masks.
+        loss_flow_L1, loss_flow_warp, body_mask_diff = \
+            self.compute_flow_losses(flow, warped_images, tgt_image, flow_gt,
+                                     flow_conf_gt, fg_mask, tgt_label,
+                                     ref_label)
+
+        loss_mask = self.compute_mask_losses(
+            occ_mask, fake_image, warped_images, tgt_label, tgt_image,
+            fg_mask, ref_fg_mask, body_mask_diff)
+
+        return loss_flow_L1, loss_flow_warp, loss_mask
+
+    def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt,
+                            flow_conf_gt, fg_mask, tgt_label, ref_label):
+        r"""Compute losses on the generated flow maps.
+
+        Args:
+            flow (tensor or list of tensors): Generated flow maps.
+            warped_images (tensor or list of tensors): Warped images using the
+                flow maps.
+            tgt_image (tensor): Target image for the warped image.
+            flow_gt (tensor or list of tensors): Ground truth flow maps.
+            flow_conf_gt (tensor or list of tensors): Confidence for the ground
+                truth flow maps.
+            fg_mask (tensor): Foreground mask for the target image.
+            tgt_label (tensor): Target label map.
+            ref_label (tensor): Reference label map.
+        Returns:
+            (dict):
+              - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+              - loss_flow_warp (tensor): L1 loss between the warped image and the
+                target image when using the flow to warp.
+              - body_mask_diff (tensor): Difference between warped body part map
+                and target body part map. Used for pose dataset only.
+        """
+        loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
+        loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
+        if isinstance(flow, list):
+            # Compute flow losses for both warping reference -> target and
+            # previous -> target.
+            for i in range(len(flow)):
+                loss_flow_L1_i, loss_flow_warp_i = \
+                    self.compute_flow_loss(flow[i], warped_images[i], tgt_image,
+                                           flow_gt[i], flow_conf_gt[i], fg_mask)
+                loss_flow_L1 += loss_flow_L1_i
+                loss_flow_warp += loss_flow_warp_i
+        else:
+            # Compute loss for warping either reference or previous images.
+            loss_flow_L1, loss_flow_warp = \
+                self.compute_flow_loss(flow, warped_images, tgt_image,
+                                       flow_gt[-1], flow_conf_gt[-1], fg_mask)
+
+        # For pose dataset only.
+        body_mask_diff = None
+        if self.warp_ref:
+            if self.for_pose_dataset:
+                # Warped reference body part map should be similar to target
+                # body part map.
+                body_mask = get_part_mask(tgt_label[:, 2])
+                ref_body_mask = get_part_mask(ref_label[:, 2])
+                warped_ref_body_mask = resample(ref_body_mask, flow[0])
+                loss_flow_warp += self.criterion(warped_ref_body_mask,
+                                                 body_mask)
+                body_mask_diff = torch.sum(
+                    abs(warped_ref_body_mask - body_mask), dim=1, keepdim=True)
+
+            if self.has_fg:
+                # Warped reference foreground map should be similar to target
+                # foreground map.
+                fg_mask, ref_fg_mask = \
+                    get_fg_mask([tgt_label, ref_label], True)
+                warped_ref_fg_mask = resample(ref_fg_mask, flow[0])
+                loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask)
+
+        return loss_flow_L1, loss_flow_warp, body_mask_diff
+
+    def compute_flow_loss(self, flow, warped_image, tgt_image, flow_gt,
+                          flow_conf_gt, fg_mask):
+        r"""Compute losses on the generated flow map.
+
+        Args:
+            flow (tensor): Generated flow map.
+            warped_image (tensor): Warped image using the flow map.
+            tgt_image (tensor): Target image for the warped image.
+            flow_gt (tensor): Ground truth flow map.
+            flow_conf_gt (tensor): Confidence for the ground truth flow map.
+            fg_mask (tensor): Foreground mask for the target image.
+        Returns:
+            (dict):
+              - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+              - loss_flow_warp (tensor): L1 loss between the warped image and
+              the target image when using the flow to warp.
+        """
+        loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
+        loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
+        if flow is not None and flow_gt is not None:
+            # L1 loss compared to flow ground truth.
+            loss_flow_L1 = self.criterionMasked(flow, flow_gt,
+                                                flow_conf_gt * fg_mask)
+        if warped_image is not None:
+            # L1 loss between warped image and target image.
+            loss_flow_warp = self.criterion(warped_image, tgt_image)
+        return loss_flow_L1, loss_flow_warp
+
+    def compute_mask_losses(self, occ_mask, fake_image, warped_image,
+                            tgt_label, tgt_image, fg_mask, ref_fg_mask,
+                            body_mask_diff):
+        r"""Compute losses on the generated occlusion masks.
+
+        Args:
+            occ_mask (tensor or list of tensors): Generated occlusion masks.
+            fake_image (tensor): Generated image.
+            warped_image (tensor or list of tensors): Warped images using the
+                flow maps.
+            tgt_label (tensor): Target label map.
+            tgt_image (tensor): Target image for the warped image.
+            fg_mask (tensor): Foreground mask for the target image.
+            ref_fg_mask (tensor): Foreground mask for the reference image.
+            body_mask_diff (tensor): Difference between warped body part map
+                and target body part map. Used for pose dataset only.
+        Returns:
+            (tensor): Loss for the mask.
+        """
+        loss_mask = torch.tensor(0., device=torch.device('cuda'))
+
+        if isinstance(occ_mask, list):
+            # Compute occlusion mask losses for both warping reference -> target
+            # and previous -> target.
+            for i in range(len(occ_mask)):
+                loss_mask += self.compute_mask_loss(occ_mask[i],
+                                                    warped_image[i],
+                                                    tgt_image)
+        else:
+            # Compute loss for warping either reference or previous images.
+            loss_mask += self.compute_mask_loss(occ_mask, warped_image,
+                                                tgt_image)
+
+        if self.warp_ref:
+            ref_occ_mask = occ_mask[0]
+            dummy0 = torch.zeros_like(ref_occ_mask)
+            dummy1 = torch.ones_like(ref_occ_mask)
+            if self.for_pose_dataset:
+                # Enforce output to use more warped reference image for
+                # face region.
+                face_mask = get_face_mask(tgt_label[:, 2]).unsqueeze(1)
+                AvgPool = torch.nn.AvgPool2d(15, padding=7, stride=1)
+                face_mask = AvgPool(face_mask)
+                loss_mask += self.criterionMasked(ref_occ_mask, dummy0,
+                                                  face_mask)
+                loss_mask += self.criterionMasked(fake_image, warped_image[0],
+                                                  face_mask)
+                # Enforce output to use more hallucinated image for discrepancy
+                # regions of body part masks between warped reference and
+                # target image.
+                loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
+                                                  body_mask_diff)
+
+            if self.has_fg:
+                # Enforce output to use more hallucinated image for discrepancy
+                # regions of foreground masks between reference and target
+                # image.
+                fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).float()
+                loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
+                                                  fg_mask_diff)
+        return loss_mask
+
+    def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
+        r"""Compute losses on the generated occlusion mask.
+
+        Args:
+            occ_mask (tensor): Generated occlusion mask.
+            warped_image (tensor): Warped image using the flow map.
+            tgt_image (tensor): Target image for the warped image.
+        Returns:
+            (tensor): Loss for the mask.
+        """
+        loss_mask = torch.tensor(0., device=torch.device('cuda'))
+        if occ_mask is not None:
+            dummy0 = torch.zeros_like(occ_mask)
+            dummy1 = torch.ones_like(occ_mask)
+
+            # Compute the confidence map based on L1 distance between warped
+            # and GT image.
+            img_diff = torch.sum(abs(warped_image - tgt_image), dim=1,
+                                 keepdim=True)
+            conf = torch.clamp(1 - img_diff, 0, 1)
+
+            # Force mask value to be small if warped image is similar to GT,
+            # and vice versa.
+            loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
+            loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)
+
+        return loss_mask
diff --git a/imaginaire/losses/gan.py b/imaginaire/losses/gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa9c30dd51887b25b439b90fa728e94fe2b03a9
--- /dev/null
+++ b/imaginaire/losses/gan.py
@@ -0,0 +1,173 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.jit.script
+def fuse_math_min_mean_pos(x):
+    r"""Fuse operation min mean for hinge loss computation of positive
+    samples"""
+    minval = torch.min(x - 1, x * 0)
+    loss = -torch.mean(minval)
+    return loss
+
+
+@torch.jit.script
+def fuse_math_min_mean_neg(x):
+    r"""Fuse operation min mean for hinge loss computation of negative
+    samples"""
+    minval = torch.min(-x - 1, x * 0)
+    loss = -torch.mean(minval)
+    return loss
+
+
+class GANLoss(nn.Module):
+    r"""GAN loss constructor.
+
+    Args:
+        gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
+            ``'non_saturated'``, ``'wasserstein'``, ``'softplus'``.
+        target_real_label (float): The desired output label for real images.
+        target_fake_label (float): The desired output label for fake images.
+        decay_k (float): The decay factor per epoch for top-k training.
+        min_k (float): The minimum percentage of samples to select.
+        separate_topk (bool): If ``True``, selects top-k for each sample
+            separately, otherwise selects top-k among all samples.
+    """
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
+                 decay_k=1., min_k=1., separate_topk=False):
+        super(GANLoss, self).__init__()
+        self.real_label = target_real_label
+        self.fake_label = target_fake_label
+        self.real_label_tensor = None
+        self.fake_label_tensor = None
+        self.gan_mode = gan_mode
+        self.decay_k = decay_k
+        self.min_k = min_k
+        self.separate_topk = separate_topk
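+        # 'k' is the fraction of discriminator outputs kept for top-k training;
+        # it starts at 1.0 (keep everything) and is annealed by topk_anneal().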
+        self.register_buffer('k', torch.tensor(1.0))
+        print('GAN mode: %s' % gan_mode)
+
+    def forward(self, dis_output, t_real, dis_update=True, reduce=True):
+        r"""GAN loss computation.
+
+        Args:
+            dis_output (tensor or list of tensors): Discriminator outputs.
+            t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target.
+            dis_update (bool): If ``True``, the loss will be used to update the discriminator, otherwise the generator.
+            reduce (bool): If ``True``, when a list of discriminator outputs are provided, it will return the average
+                of all losses, otherwise it will return a list of losses.
+        Returns:
+            loss (tensor): Loss value.
+        """
+        if isinstance(dis_output, list):
+            # For multi-scale discriminators.
+            # In this implementation, the loss is first averaged for each scale
+            # (batch size and number of locations) then averaged across scales,
+            # so that the gradient is not dominated by the discriminator that
+            # has the most output values (highest resolution).
+            losses = []
+            for dis_output_i in dis_output:
+                assert isinstance(dis_output_i, torch.Tensor)
+                losses.append(self.loss(dis_output_i, t_real, dis_update))
+            if reduce:
+                return torch.mean(torch.stack(losses))
+            else:
+                return losses
+        else:
+            return self.loss(dis_output, t_real, dis_update)
+
+    def loss(self, dis_output, t_real, dis_update=True):
+        r"""GAN loss computation.
+
+        Args:
+            dis_output (tensor): Discriminator outputs.
+            t_real (bool): If ``True``, uses the real label as target, otherwise
+                uses the fake label as target.
+            dis_update (bool): Updating the discriminator or the generator.
+        Returns:
+            loss (tensor): Loss value.
+        """
+        if not dis_update:
+            assert t_real, \
+                "The target should be real when updating the generator."
+
+        if not dis_update and self.k < 1:
+            r"""
+            Use top-k training:
+            "Top-k Training of GANs: Improving GAN Performance by Throwing
+            Away Bad Samples"
+            Here, each sample may have multiple discriminator output values
+            (patch discriminator). We could either select top-k for each sample
+            separately (when ``self.separate_topk=True``), or collect values
+            from all samples and then select top-k (default, when
+            ``self.separate_topk=False``).
+            """
+            if self.separate_topk:
+                dis_output = dis_output.view(dis_output.size(0), -1)
+            else:
+                dis_output = dis_output.view(-1)
+            k = math.ceil(self.k * dis_output.size(-1))
+            dis_output, _ = torch.topk(dis_output, k)
+
+        if self.gan_mode == 'non_saturated':
+            target_tensor = self.get_target_tensor(dis_output, t_real)
+            loss = F.binary_cross_entropy_with_logits(dis_output,
+                                                      target_tensor)
+        elif self.gan_mode == 'least_square':
+            target_tensor = self.get_target_tensor(dis_output, t_real)
+            loss = 0.5 * F.mse_loss(dis_output, target_tensor)
+        elif self.gan_mode == 'hinge':
+            if dis_update:
+                if t_real:
+                    loss = fuse_math_min_mean_pos(dis_output)
+                else:
+                    loss = fuse_math_min_mean_neg(dis_output)
+            else:
+                loss = -torch.mean(dis_output)
+        elif self.gan_mode == 'wasserstein':
+            if t_real:
+                loss = -torch.mean(dis_output)
+            else:
+                loss = torch.mean(dis_output)
+        elif self.gan_mode == 'softplus':
+            target_tensor = self.get_target_tensor(dis_output, t_real)
+            loss = F.binary_cross_entropy_with_logits(dis_output,
+                                                      target_tensor)
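+            # With 0/1 targets, BCE-with-logits reduces to softplus(-x) for real
+            # and softplus(x) for fake, i.e. the softplus GAN loss.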
+        else:
+            raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
+        return loss
+
+    def get_target_tensor(self, dis_output, t_real):
+        r"""Return the target vector for the binary cross entropy loss
+        computation.
+
+        Args:
+            dis_output (tensor): Discriminator outputs.
+            t_real (bool): If ``True``, uses the real label as target, otherwise
+                uses the fake label as target.
+        Returns:
+            target (tensor): Target tensor vector.
+        """
+        if t_real:
+            if self.real_label_tensor is None:
+                self.real_label_tensor = dis_output.new_tensor(self.real_label)
+            return self.real_label_tensor.expand_as(dis_output)
+        else:
+            if self.fake_label_tensor is None:
+                self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
+            return self.fake_label_tensor.expand_as(dis_output)
+
+    def topk_anneal(self):
+        r"""Anneal k after each epoch."""
+        if self.decay_k < 1:
+            # noinspection PyAttributeOutsideInit
+            self.k.fill_(max(self.decay_k * self.k, self.min_k))
+            print("Top-k training: update k to {}.".format(self.k))
diff --git a/imaginaire/losses/info_nce.py b/imaginaire/losses/info_nce.py
new file mode 100644
index 0000000000000000000000000000000000000000..8033e828f0b99d12d6e8f8b71811982d0ab568f6
--- /dev/null
+++ b/imaginaire/losses/info_nce.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from imaginaire.utils.distributed import get_world_size, get_rank, \
+    dist_all_reduce_tensor
+
+
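+# Differentiable all_gather: the forward collects the input tensor from every
+# rank, and the backward sums the incoming gradients across ranks and returns
+# the slice that belongs to the local rank.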
+class GatherLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
+        dist.all_gather(output, input)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        input, = ctx.saved_tensors
+        grad_out = torch.zeros_like(input)
+        all_grads = torch.stack(grads)
+        all_grads = dist_all_reduce_tensor(all_grads, reduce='sum')
+        grad_out[:] = all_grads[get_rank()]
+        return grad_out
+
+
+class InfoNCELoss(nn.Module):
+    def __init__(self,
+                 temperature=0.07,
+                 gather_distributed=True,
+                 learn_temperature=True,
+                 single_direction=False,
+                 flatten=True):
+        super(InfoNCELoss, self).__init__()
+        self.logit_scale = nn.Parameter(torch.tensor([math.log(1/temperature)]))
+        self.logit_scale.requires_grad = learn_temperature
+        self.gather_distributed = gather_distributed
+        self.single_direction = single_direction
+        self.flatten = flatten
+
+    def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8):
+        if gather_distributed is None:
+            gather_distributed = self.gather_distributed
+
+        if features_a is None or features_b is None:
+            return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda')
+
+        bs_a, bs_b = features_a.size(0), features_b.size(0)
+        if self.flatten:
+            features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1)
+        else:
+            features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1)
+            features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1)
+
+        # Temperature clipping.
+        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)
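+        # 4.6052 ~= ln(100), so the effective scale exp(logit_scale) stays in
+        # roughly [1, 100].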
+
+        # normalized features
+        features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps)
+        features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps)
+
+        loss_a = self._forward_single_direction(features_a, features_b, gather_distributed)
+        if self.single_direction:
+            return loss_a
+        else:
+            loss_b = self._forward_single_direction(features_b, features_a, gather_distributed)
+            return loss_a + loss_b
+
+    def _forward_single_direction(
+            self, features_a, features_b, gather_distributed):
+        bs_a = features_a.shape[0]
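+        # The positive for sample i of features_a is sample i of features_b
+        # (offset by rank * batch size when features_b is gathered across GPUs).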
+        logit_scale = self.logit_scale.exp()
+        if get_world_size() > 1 and gather_distributed:
+            gather_features_b = torch.cat(GatherLayer.apply(features_b))
+            gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a
+            logits_a = logit_scale * features_a @ gather_features_b.t()
+        else:
+            gather_labels_a = torch.arange(bs_a, device='cuda')
+            logits_a = logit_scale * features_a @ features_b.t()
+        loss_a = F.cross_entropy(logits_a, gather_labels_a)
+        return loss_a
diff --git a/imaginaire/losses/kl.py b/imaginaire/losses/kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc9a2da14b06ccfef143312acd20ccd3784bdb34
--- /dev/null
+++ b/imaginaire/losses/kl.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+class GaussianKLLoss(nn.Module):
+    r"""Compute KL loss in VAE for Gaussian distributions"""
+    def __init__(self):
+        super(GaussianKLLoss, self).__init__()
+
+    def forward(self, mu, logvar=None):
+        r"""Compute loss
+
+        Args:
+            mu (tensor): mean
+            logvar (tensor): logarithm of variance
+        """
+        if logvar is None:
+            logvar = torch.zeros_like(mu)
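+        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), averaged over all elements.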
+        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
diff --git a/imaginaire/losses/perceptual.py b/imaginaire/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..424656fa09b65333e4fa28cae2de7114de69ebfa
--- /dev/null
+++ b/imaginaire/losses/perceptual.py
@@ -0,0 +1,395 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, distributed as dist
+
+from imaginaire.losses.info_nce import InfoNCELoss
+from imaginaire.utils.distributed import master_only_print as print, \
+    is_local_master
+from imaginaire.utils.misc import apply_imagenet_normalization, to_float
+
+
+class PerceptualLoss(nn.Module):
+    r"""Perceptual loss initialization.
+
+    Args:
+       network (str): The name of the loss network: 'vgg19' | 'vgg16' |
+        'alexnet' | 'inception_v3' | 'resnet50' | 'robust_resnet50' |
+        'vgg_face_dag'.
+       layers (str or list of str): The layers used to compute the loss.
+       weights (float or list of float): The loss weights of each layer.
+       criterion (str): The type of distance function: 'l1' | 'l2' | 'mse' |
+        'info_nce'.
+       resize (bool): If ``True``, resize the input images to 224x224.
+       resize_mode (str): Algorithm used for resizing.
+       num_scales (int): The number of scales at which the loss is evaluated;
+        the input is downsampled by a factor of 2 between consecutive scales.
+       per_sample_weight (bool): If ``True``, output the loss for individual
+        samples in the batch instead of the mean loss.
+    """
+
+    def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
+                 criterion='l1', resize=False, resize_mode='bilinear',
+                 num_scales=1, per_sample_weight=False,
+                 info_nce_temperature=0.07,
+                 info_nce_gather_distributed=True,
+                 info_nce_learn_temperature=True,
+                 info_nce_flatten=True):
+        super().__init__()
+        if isinstance(layers, str):
+            layers = [layers]
+        if weights is None:
+            weights = [1.] * len(layers)
+        elif isinstance(weights, float) or isinstance(weights, int):
+            weights = [weights]
+
+        if dist.is_initialized() and not is_local_master():
+            # Make sure only the first process in distributed training downloads
+            # the model, and the others will use the cache
+            # noinspection PyUnresolvedReferences
+            torch.distributed.barrier()
+
+        assert len(layers) == len(weights), \
+            'The number of layers (%s) must be equal to ' \
+            'the number of weights (%s).' % (len(layers), len(weights))
+        if network == 'vgg19':
+            self.model = _vgg19(layers)
+        elif network == 'vgg16':
+            self.model = _vgg16(layers)
+        elif network == 'alexnet':
+            self.model = _alexnet(layers)
+        elif network == 'inception_v3':
+            self.model = _inception_v3(layers)
+        elif network == 'resnet50':
+            self.model = _resnet50(layers)
+        elif network == 'robust_resnet50':
+            self.model = _robust_resnet50(layers)
+        elif network == 'vgg_face_dag':
+            self.model = _vgg_face_dag(layers)
+        else:
+            raise ValueError('Network %s is not recognized' % network)
+
+        if dist.is_initialized() and is_local_master():
+            # Make sure only the first process in distributed training downloads
+            # the model, and the others will use the cache
+            # noinspection PyUnresolvedReferences
+            torch.distributed.barrier()
+
+        self.num_scales = num_scales
+        self.layers = layers
+        self.weights = weights
+        reduction = 'mean' if not per_sample_weight else 'none'
+        if criterion == 'l1':
+            self.criterion = nn.L1Loss(reduction=reduction)
+        elif criterion == 'l2' or criterion == 'mse':
+            self.criterion = nn.MSELoss(reduction=reduction)
+        elif criterion == 'info_nce':
+            self.criterion = InfoNCELoss(
+                temperature=info_nce_temperature,
+                gather_distributed=info_nce_gather_distributed,
+                learn_temperature=info_nce_learn_temperature,
+                flatten=info_nce_flatten,
+                single_direction=True
+            )
+        else:
+            raise ValueError('Criterion %s is not recognized' % criterion)
+        self.resize = resize
+        self.resize_mode = resize_mode
+        print('Perceptual loss:')
+        print('\tMode: {}'.format(network))
+
+    def forward(self, inp, target, per_sample_weights=None):
+        r"""Perceptual loss forward.
+
+        Args:
+           inp (4D tensor) : Input tensor.
+           target (4D tensor) : Ground truth tensor, same shape as the input.
+           per_sample_weights (tensor or None): If given, the per-layer losses
+            are reduced to one value per sample rather than a batch mean.
+        Returns:
+           (scalar tensor) : The perceptual loss.
+        """
+        if not torch.is_autocast_enabled():
+            inp, target = to_float([inp, target])
+
+        # Perceptual loss should operate in eval mode by default.
+        self.model.eval()
+        inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target)
+        if self.resize:
+            inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False)
+            target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False)
+
+        # Evaluate perceptual loss at each scale.
+        loss = 0
+        for scale in range(self.num_scales):
+            input_features, target_features = self.model(inp), self.model(target)
+
+            for layer, weight in zip(self.layers, self.weights):
+                # Example per-layer VGG19 loss values after applying
+                # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
+                # relu_1_1, 0.014698
+                # relu_2_1, 0.085817
+                # relu_3_1, 0.349977
+                # relu_4_1, 0.544188
+                # relu_5_1, 0.906261
+                # print('%s, %f' % (
+                #     layer,
+                #     weight * self.criterion(
+                #                  input_features[layer],
+                #                  target_features[
+                #                  layer].detach()).item()))
+                l_tmp = self.criterion(input_features[layer], target_features[layer].detach())
+                if per_sample_weights is not None:
+                    l_tmp = l_tmp.mean(1).mean(1).mean(1)
+                loss += weight * l_tmp
+            # Downsample the input and target.
+            if scale != self.num_scales - 1:
+                inp = F.interpolate(
+                    inp, mode=self.resize_mode, scale_factor=0.5,
+                    align_corners=False, recompute_scale_factor=True)
+                target = F.interpolate(
+                    target, mode=self.resize_mode, scale_factor=0.5,
+                    align_corners=False, recompute_scale_factor=True)
+
+        return loss.float()
+
+
+class _PerceptualNetwork(nn.Module):
+    r"""The network that extracts features to compute the perceptual loss.
+
+    Args:
+        network (nn.Sequential) : The network that extracts features.
+        layer_name_mapping (dict) : The dictionary that
+            maps a layer's index to its name.
+        layers (list of str): The list of layer names that we are using.
+    """
+
+    def __init__(self, network, layer_name_mapping, layers):
+        super().__init__()
+        assert isinstance(network, nn.Sequential), \
+            'The network needs to be of type "nn.Sequential".'
+        self.network = network
+        self.layer_name_mapping = layer_name_mapping
+        self.layers = layers
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, x):
+        r"""Extract perceptual features."""
+        output = {}
+        for i, layer in enumerate(self.network):
+            x = layer(x)
+            layer_name = self.layer_name_mapping.get(i, None)
+            if layer_name in self.layers:
+                # If the current layer is used by the perceptual loss.
+                output[layer_name] = x
+        return output
+
+
+def _vgg19(layers):
+    r"""Get vgg19 layers"""
+    vgg = torchvision.models.vgg19(pretrained=True)
+    # network = vgg.features
+    network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier)))
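+    # The avgpool, flatten, and classifier modules are appended after the
+    # convolutional features so that fully-connected activations (e.g. 'fc_2'
+    # at index 42) can also be extracted as perceptual features.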
+    layer_name_mapping = {1: 'relu_1_1',
+                          3: 'relu_1_2',
+                          6: 'relu_2_1',
+                          8: 'relu_2_2',
+                          11: 'relu_3_1',
+                          13: 'relu_3_2',
+                          15: 'relu_3_3',
+                          17: 'relu_3_4',
+                          20: 'relu_4_1',
+                          22: 'relu_4_2',
+                          24: 'relu_4_3',
+                          26: 'relu_4_4',
+                          29: 'relu_5_1',
+                          31: 'relu_5_2',
+                          33: 'relu_5_3',
+                          35: 'relu_5_4',
+                          36: 'pool_5',
+                          42: 'fc_2'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _vgg16(layers):
+    r"""Get vgg16 layers"""
+    network = torchvision.models.vgg16(pretrained=True).features
+    layer_name_mapping = {1: 'relu_1_1',
+                          3: 'relu_1_2',
+                          6: 'relu_2_1',
+                          8: 'relu_2_2',
+                          11: 'relu_3_1',
+                          13: 'relu_3_2',
+                          15: 'relu_3_3',
+                          18: 'relu_4_1',
+                          20: 'relu_4_2',
+                          22: 'relu_4_3',
+                          25: 'relu_5_1'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _alexnet(layers):
+    r"""Get alexnet layers"""
+    network = torchvision.models.alexnet(pretrained=True).features
+    layer_name_mapping = {0: 'conv_1',
+                          1: 'relu_1',
+                          3: 'conv_2',
+                          4: 'relu_2',
+                          6: 'conv_3',
+                          7: 'relu_3',
+                          8: 'conv_4',
+                          9: 'relu_4',
+                          10: 'conv_5',
+                          11: 'relu_5'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _inception_v3(layers):
+    r"""Get inception v3 layers"""
+    inception = torchvision.models.inception_v3(pretrained=True)
+    network = nn.Sequential(inception.Conv2d_1a_3x3,
+                            inception.Conv2d_2a_3x3,
+                            inception.Conv2d_2b_3x3,
+                            nn.MaxPool2d(kernel_size=3, stride=2),
+                            inception.Conv2d_3b_1x1,
+                            inception.Conv2d_4a_3x3,
+                            nn.MaxPool2d(kernel_size=3, stride=2),
+                            inception.Mixed_5b,
+                            inception.Mixed_5c,
+                            inception.Mixed_5d,
+                            inception.Mixed_6a,
+                            inception.Mixed_6b,
+                            inception.Mixed_6c,
+                            inception.Mixed_6d,
+                            inception.Mixed_6e,
+                            inception.Mixed_7a,
+                            inception.Mixed_7b,
+                            inception.Mixed_7c,
+                            nn.AdaptiveAvgPool2d(output_size=(1, 1)))
+    layer_name_mapping = {3: 'pool_1',
+                          6: 'pool_2',
+                          14: 'mixed_6e',
+                          18: 'pool_3'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _resnet50(layers):
+    r"""Get resnet50 layers"""
+    resnet50 = torchvision.models.resnet50(pretrained=True)
+    network = nn.Sequential(resnet50.conv1,
+                            resnet50.bn1,
+                            resnet50.relu,
+                            resnet50.maxpool,
+                            resnet50.layer1,
+                            resnet50.layer2,
+                            resnet50.layer3,
+                            resnet50.layer4,
+                            resnet50.avgpool)
+    layer_name_mapping = {4: 'layer_1',
+                          5: 'layer_2',
+                          6: 'layer_3',
+                          7: 'layer_4'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _robust_resnet50(layers):
+    r"""Get robust resnet50 layers"""
+    resnet50 = torchvision.models.resnet50(pretrained=False)
+    state_dict = torch.utils.model_zoo.load_url(
+        'http://andrewilyas.com/ImageNet.pt')
+    new_state_dict = {}
+    for k, v in state_dict['model'].items():
+        if k.startswith('module.model.'):
+            new_state_dict[k[13:]] = v
+    resnet50.load_state_dict(new_state_dict)
+    network = nn.Sequential(resnet50.conv1,
+                            resnet50.bn1,
+                            resnet50.relu,
+                            resnet50.maxpool,
+                            resnet50.layer1,
+                            resnet50.layer2,
+                            resnet50.layer3,
+                            resnet50.layer4,
+                            resnet50.avgpool)
+    layer_name_mapping = {4: 'layer_1',
+                          5: 'layer_2',
+                          6: 'layer_3',
+                          7: 'layer_4'}
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _vgg_face_dag(layers):
+    network = torchvision.models.vgg16(num_classes=2622)
+    state_dict = torch.utils.model_zoo.load_url(
+        'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
+        'vgg_face_dag.pth')
+    feature_layer_name_mapping = {
+        0: 'conv1_1',
+        2: 'conv1_2',
+        5: 'conv2_1',
+        7: 'conv2_2',
+        10: 'conv3_1',
+        12: 'conv3_2',
+        14: 'conv3_3',
+        17: 'conv4_1',
+        19: 'conv4_2',
+        21: 'conv4_3',
+        24: 'conv5_1',
+        26: 'conv5_2',
+        28: 'conv5_3'}
+    new_state_dict = {}
+    for k, v in feature_layer_name_mapping.items():
+        new_state_dict['features.' + str(k) + '.weight'] = \
+            state_dict[v + '.weight']
+        new_state_dict['features.' + str(k) + '.bias'] = \
+            state_dict[v + '.bias']
+
+    classifier_layer_name_mapping = {
+        0: 'fc6',
+        3: 'fc7',
+        6: 'fc8'}
+    for k, v in classifier_layer_name_mapping.items():
+        new_state_dict['classifier.' + str(k) + '.weight'] = \
+            state_dict[v + '.weight']
+        new_state_dict['classifier.' + str(k) + '.bias'] = \
+            state_dict[v + '.bias']
+
+    network.load_state_dict(new_state_dict)
+
+    class Flatten(nn.Module):
+        def forward(self, x):
+            return x.view(x.shape[0], -1)
+
+    layer_name_mapping = {
+        0: 'conv_1_1',
+        1: 'relu_1_1',
+        2: 'conv_1_2',
+        5: 'conv_2_1',  # 1/2
+        6: 'relu_2_1',
+        7: 'conv_2_2',
+        10: 'conv_3_1',  # 1/4
+        11: 'relu_3_1',
+        12: 'conv_3_2',
+        14: 'conv_3_3',
+        17: 'conv_4_1',  # 1/8
+        18: 'relu_4_1',
+        19: 'conv_4_2',
+        21: 'conv_4_3',
+        24: 'conv_5_1',  # 1/16
+        25: 'relu_5_1',
+        26: 'conv_5_2',
+        28: 'conv_5_3',
+        33: 'fc6',
+        36: 'fc7',
+        39: 'fc8'
+    }
+    seq_layers = []
+    for feature in network.features:
+        seq_layers += [feature]
+    seq_layers += [network.avgpool, Flatten()]
+    for classifier in network.classifier:
+        seq_layers += [classifier]
+    network = nn.Sequential(*seq_layers)
+    return _PerceptualNetwork(network, layer_name_mapping, layers)
diff --git a/imaginaire/losses/weighted_mse.py b/imaginaire/losses/weighted_mse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e49989a5c3ee8576dcf4dea8a98a16c1911cc9
--- /dev/null
+++ b/imaginaire/losses/weighted_mse.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+
+class WeightedMSELoss(nn.Module):
+    r"""Compute Weighted MSE loss"""
+    def __init__(self, reduction='mean'):
+        super(WeightedMSELoss, self).__init__()
+        self.reduction = reduction
+
+    def forward(self, input, target, weight):
+        r"""Return weighted MSE Loss.
+        Args:
+           input (tensor): Input tensor.
+           target (tensor): Target tensor.
+           weight (tensor): Element-wise weight applied to the squared error.
+        Returns:
+           (tensor): Loss value.
+        """
+        if self.reduction == 'mean':
+            loss = torch.mean(weight * (input - target) ** 2)
+        else:
+            loss = torch.sum(weight * (input - target) ** 2)
+        return loss
diff --git a/imaginaire/model_utils/__init__.py b/imaginaire/model_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/model_utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/model_utils/fs_vid2vid.py b/imaginaire/model_utils/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52faf73d3f37221e9d0f089d1f4a85d5378877d
--- /dev/null
+++ b/imaginaire/model_utils/fs_vid2vid.py
@@ -0,0 +1,865 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Utils for the few shot vid2vid model."""
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def resample(image, flow):
+    r"""Resamples an image using the provided flow.
+
+    Args:
+        image (NxCxHxW tensor) : Image to resample.
+        flow (Nx2xHxW tensor) : Optical flow to resample the image.
+    Returns:
+        output (NxCxHxW tensor) : Resampled image.
+    """
+    assert flow.shape[1] == 2
+    b, c, h, w = image.size()
+    grid = get_grid(b, (h, w))
+    flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
+                      flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
+    final_grid = (grid + flow).permute(0, 2, 3, 1)
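+    # Older PyTorch versions of grid_sample do not accept align_corners;
+    # fall back to the default behavior if the call raises.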
+    try:
+        output = F.grid_sample(image, final_grid, mode='bilinear',
+                               padding_mode='border', align_corners=True)
+    except Exception:
+        output = F.grid_sample(image, final_grid, mode='bilinear',
+                               padding_mode='border')
+    return output
+
+
+def get_grid(batchsize, size, minval=-1.0, maxval=1.0):
+    r"""Get a grid ranging [-1, 1] of 2D/3D coordinates.
+
+    Args:
+        batchsize (int) : Batch size.
+        size (tuple) : (height, width) or (depth, height, width).
+        minval (float) : minimum value in returned grid.
+        maxval (float) : maximum value in returned grid.
+    Returns:
+        t_grid (4D tensor) : Grid of coordinates.
+    """
+    if len(size) == 2:
+        rows, cols = size
+    elif len(size) == 3:
+        deps, rows, cols = size
+    else:
+        raise ValueError('Dimension can only be 2 or 3.')
+    x = torch.linspace(minval, maxval, cols)
+    x = x.view(1, 1, 1, cols)
+    x = x.expand(batchsize, 1, rows, cols)
+
+    y = torch.linspace(minval, maxval, rows)
+    y = y.view(1, 1, rows, 1)
+    y = y.expand(batchsize, 1, rows, cols)
+
+    t_grid = torch.cat([x, y], dim=1)
+
+    if len(size) == 3:
+        z = torch.linspace(minval, maxval, deps)
+        z = z.view(1, 1, deps, 1, 1)
+        z = z.expand(batchsize, 1, deps, rows, cols)
+
+        t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols)
+        t_grid = torch.cat([t_grid, z], dim=1)
+
+    t_grid.requires_grad = False
+    return t_grid.to('cuda')
+
+
+def pick_image(images, idx):
+    r"""Pick the image among images according to idx.
+
+    Args:
+        images (B x N x C x H x W tensor or list of tensors) : N images.
+        idx (B tensor) : indices to select.
+    Returns:
+        image (B x C x H x W) : Selected images.
+    """
+    if type(images) == list:
+        return [pick_image(r, idx) for r in images]
+    if idx is None:
+        return images[:, 0]
+    elif type(idx) == int:
+        return images[:, idx]
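+    # Gather a (possibly different) index along the N dimension for each sample
+    # in the batch.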
+    idx = idx.long().view(-1, 1, 1, 1, 1)
+    image = images.gather(1, idx.expand_as(images)[:, 0:1])[:, 0]
+    return image
+
+
+def crop_face_from_data(cfg, is_inference, data):
+    r"""Crop the face regions in input data and resize to the target size.
+    This is for training face datasets.
+
+    Args:
+        cfg (obj): Data configuration.
+        is_inference (bool): Is doing inference or not.
+        data (dict): Input data.
+    Returns:
+        data (dict): Cropped data.
+    """
+    label = data['label'] if 'label' in data else None
+    image = data['images']
+    landmarks = data['landmarks-dlib68_xy']
+    ref_labels = data['few_shot_label'] if 'few_shot_label' in data else None
+    ref_images = data['few_shot_images']
+    ref_landmarks = data['few_shot_landmarks-dlib68_xy']
+    img_size = image.shape[-2:]
+    h, w = cfg.output_h_w.split(',')
+    h, w = int(h), int(w)
+
+    # When doing inference, common attributes such as the crop coordinates need
+    # to be synced between workers, so that all workers crop the same region.
+    if 'common_attr' in data and 'crop_coords' in data['common_attr']:
+        # Has been computed before, reusing the previous one.
+        crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
+    else:
+        # Is the first frame, need to compute the bbox.
+        ref_crop_coords, scale = get_face_bbox_for_data(
+            ref_landmarks[0], img_size, None, is_inference)
+        crop_coords, _ = get_face_bbox_for_data(
+            landmarks[0], img_size, scale, is_inference)
+
+    # Crop the images according to the bbox and resize them to target size.
+    label, image = crop_and_resize([label, image], crop_coords, (h, w))
+    ref_labels, ref_images = crop_and_resize([ref_labels, ref_images],
+                                             ref_crop_coords, (h, w))
+
+    data['images'], data['few_shot_images'] = image, ref_images
+    if label is not None:
+        data['label'], data['few_shot_label'] = label, ref_labels
+    if is_inference:
+        if 'common_attr' not in data:
+            data['common_attr'] = dict()
+        data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
+    return data
+
+
+def get_face_bbox_for_data(keypoints, orig_img_size, scale, is_inference):
+    r"""Get the bbox coordinates for face region.
+
+    Args:
+        keypoints (Nx2 tensor): Facial landmarks.
+        orig_img_size (int tuple): Height and width of the input image size.
+        scale (float): When training, randomly scale the crop size for
+            augmentation.
+        is_inference (bool): Is doing inference or not.
+    Returns:
+        crop_coords (list of int): bbox for face region.
+        scale (float): Also returns the scale to ensure reference and target
+            frames are cropped using the same scale.
+    """
+    min_y, max_y = int(keypoints[:, 1].min()), int(keypoints[:, 1].max())
+    min_x, max_x = int(keypoints[:, 0].min()), int(keypoints[:, 0].max())
+    x_cen, y_cen = (min_x + max_x) // 2, (min_y + max_y) // 2
+    H, W = orig_img_size
+    w = h = (max_x - min_x)
+    if not is_inference:
+        # During training, randomly jitter the cropping position by offset
+        # amount for augmentation.
+        offset_max = 0.2
+        offset = [np.random.uniform(-offset_max, offset_max),
+                  np.random.uniform(-offset_max, offset_max)]
+        # Also augment the crop size.
+        if scale is None:
+            scale_max = 0.2
+            scale = [np.random.uniform(1 - scale_max, 1 + scale_max),
+                     np.random.uniform(1 - scale_max, 1 + scale_max)]
+        w *= scale[0]
+        h *= scale[1]
+        x_cen += int(offset[0] * w)
+        y_cen += int(offset[1] * h)
+
+    # Get the cropping coordinates.
+    x_cen = max(w, min(W - w, x_cen))
+    y_cen = max(h * 1.25, min(H - h * 0.75, y_cen))
+
+    min_x = x_cen - w
+    min_y = y_cen - h * 1.25
+    max_x = min_x + w * 2
+    max_y = min_y + h * 2
+
+    crop_coords = [min_y, max_y, min_x, max_x]
+    return [int(x) for x in crop_coords], scale
+
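+# A minimal usage sketch (hypothetical landmark values, not from the dataset):
+# at inference time no jitter/scale augmentation is applied, and the returned
+# scale is simply passed through so reference and target share the same one.
+#   landmarks = torch.tensor([[100., 120.], [140., 160.]])   # (x, y) pairs
+#   coords, scale = get_face_bbox_for_data(landmarks, (256, 256), None, True)
+#   # coords is [min_y, max_y, min_x, max_x]; scale stays None here.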
+
+def crop_person_from_data(cfg, is_inference, data):
+    r"""Crop the person regions in data and resize to the target size.
+    This is for training full body datasets.
+
+    Args:
+        cfg (obj): Data configuration.
+        is_inference (bool): Is doing inference or not.
+        data (dict): Input data.
+    Returns:
+        data (dict): Cropped data.
+    """
+    label = data['label']
+    image = data['images']
+    use_few_shot = 'few_shot_label' in data
+    if use_few_shot:
+        ref_labels = data['few_shot_label']
+        ref_images = data['few_shot_images']
+
+    img_size = image.shape[-2:]
+    output_h, output_w = cfg.output_h_w.split(',')
+    output_h, output_w = int(output_h), int(output_w)
+    output_aspect_ratio = output_w / output_h
+
+    if 'human_instance_maps' in data:
+        # Remove other people in the DensePose map except for the current
+        # target.
+        label = remove_other_ppl(label, data['human_instance_maps'])
+        if use_few_shot:
+            ref_labels = remove_other_ppl(ref_labels,
+                                          data['few_shot_human_instance_maps'])
+
+    # Randomly jitter the crop position by offset amount for augmentation.
+    offset = ref_offset = None
+    if not is_inference:
+        offset = np.random.randn(2) * 0.05
+        offset = np.minimum(1, np.maximum(-1, offset))
+        ref_offset = np.random.randn(2) * 0.02
+        ref_offset = np.minimum(1, np.maximum(-1, ref_offset))
+
+    # Randomly scale the crop size for augmentation.
+    # Final cropped size = person height * scale.
+    scale = ref_scale = 1.5
+    if not is_inference:
+        scale = min(2, max(1, scale + np.random.randn() * 0.05))
+        ref_scale = min(2, max(1, ref_scale + np.random.randn() * 0.02))
+
+    # When doing inference, we need to sync common attributes such as the crop
+    # coordinates between different workers, so all workers crop the same region.
+    if 'common_attr' in data:
+        # Has been computed before, reusing the previous one.
+        crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
+    else:
+        # Is the first frame, need to compute the bbox.
+        crop_coords = get_person_bbox_for_data(label, img_size, scale,
+                                               output_aspect_ratio, offset)
+        if use_few_shot:
+            ref_crop_coords = get_person_bbox_for_data(
+                ref_labels, img_size, ref_scale,
+                output_aspect_ratio, ref_offset)
+        else:
+            ref_crop_coords = None
+
+    # Crop the images according to the bbox and resize them to target size.
+    label = crop_and_resize(label, crop_coords, (output_h, output_w), 'nearest')
+    image = crop_and_resize(image, crop_coords, (output_h, output_w))
+    if use_few_shot:
+        ref_labels = crop_and_resize(ref_labels, ref_crop_coords,
+                                     (output_h, output_w), 'nearest')
+        ref_images = crop_and_resize(ref_images, ref_crop_coords,
+                                     (output_h, output_w))
+
+    data['label'], data['images'] = label, image
+    if use_few_shot:
+        data['few_shot_label'], data['few_shot_images'] = ref_labels, ref_images
+    if 'human_instance_maps' in data:
+        del data['human_instance_maps']
+    if 'few_shot_human_instance_maps' in data:
+        del data['few_shot_human_instance_maps']
+    if is_inference:
+        data['common_attr'] = dict()
+        data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
+
+    return data
+
+
+def get_person_bbox_for_data(pose_map, orig_img_size, scale=1.5,
+                             crop_aspect_ratio=1, offset=None):
+    r"""Get the bbox (pixel coordinates) to crop for person body region.
+
+    Args:
+        pose_map (NxCxHxW tensor): Input pose map.
+        orig_img_size (int tuple): Height and width of the input image size.
+        scale (float): When training, randomly scale the crop size for
+        augmentation.
+        crop_aspect_ratio (float): Output aspect ratio.
+        offset (list of float): Offset for crop position.
+    Returns:
+        crop_coords (list of int): bbox for body region.
+    """
+    H, W = orig_img_size
+    assert pose_map.dim() == 4
+    nonzero_indices = (pose_map[:, :3] > 0).nonzero(as_tuple=False)
+    if nonzero_indices.size(0) == 0:
+        bw = int(H * crop_aspect_ratio // 2)
+        return [0, H, W // 2 - bw, W // 2 + bw]
+
+    y_indices, x_indices = nonzero_indices[:, 2], nonzero_indices[:, 3]
+    y_min, y_max = y_indices.min().item(), y_indices.max().item()
+    x_min, x_max = x_indices.min().item(), x_indices.max().item()
+    y_cen = int(y_min + y_max) // 2
+    x_cen = int(x_min + x_max) // 2
+    y_len = y_max - y_min
+    x_len = x_max - x_min
+
+    # bh, bw: half of height / width of final cropped size.
+    bh = int(min(H, max(H // 2, y_len * scale))) // 2
+    bh = max(bh, int(x_len * scale / crop_aspect_ratio) // 2)
+    bw = int(bh * crop_aspect_ratio)
+
+    # Randomly offset the cropped position for augmentation.
+    if offset is not None:
+        x_cen += int(offset[0] * bw)
+        y_cen += int(offset[1] * bh)
+    x_cen = max(bw, min(W - bw, x_cen))
+    y_cen = max(bh, min(H - bh, y_cen))
+
+    return [(y_cen - bh), (y_cen + bh), (x_cen - bw), (x_cen + bw)]
+
+
+def crop_and_resize(img, coords, size=None, method='bilinear'):
+    r"""Crop the image using the given coordinates and resize to target size.
+
+    Args:
+        img (tensor or list of tensors): Input image.
+        coords (list of int): Pixel coordinates to crop.
+        size (list of int): Output size.
+        method (str): Interpolation method.
+    Returns:
+        img (tensor or list of tensors): Output image.
+    """
+    if isinstance(img, list):
+        return [crop_and_resize(x, coords, size, method) for x in img]
+    if img is None:
+        return None
+    min_y, max_y, min_x, max_x = coords
+
+    img = img[:, :, min_y:max_y, min_x:max_x]
+    if size is not None:
+        if method == 'nearest':
+            img = F.interpolate(img, size=size, mode=method)
+        else:
+            img = F.interpolate(img, size=size, mode=method,
+                                align_corners=False)
+    return img
+
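+# A minimal usage sketch (hypothetical shapes): coords follow the
+# [min_y, max_y, min_x, max_x] convention used above, and `size` is (H, W).
+#   img = torch.randn(1, 3, 256, 256)
+#   out = crop_and_resize(img, [32, 224, 32, 224], size=(128, 128))
+#   # out.shape == (1, 3, 128, 128)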
+
+def remove_other_ppl(labels, densemasks):
+    r"""Remove other people in the label map except for the current target
+    by looking at the id in the densemask map.
+
+    Args:
+        labels (NxCxHxW tensor): Input labels.
+        densemasks (Nx1xHxW tensor): Densemask maps.
+    Returns:
+        labels (NxCxHxW tensor): Output labels.
+    """
+    densemasks = densemasks[:, 0:1] * 255
+    for idx in range(labels.shape[0]):
+        label, densemask = labels[idx], densemasks[idx]
+        # Get OpenPose and find the person id in Densemask that has the most
+        # overlap with the person in OpenPose result.
+        openpose = label[3:]
+        valid = (openpose[0] > 0) | (openpose[1] > 0) | (openpose[2] > 0)
+        dp_valid = densemask[valid.unsqueeze(0)]
+        if dp_valid.shape[0]:
+            ind = np.bincount(dp_valid).argmax()
+            # Remove all other people that have different indices.
+            label = label * (densemask == ind).float()
+        labels[idx] = label
+    return labels
+
+
+def select_object(data, obj_indices=None):
+    r"""Select the object/person in the dict according to the object index.
+    Currently it is used to select the target person in the OpenPose dict.
+
+    Args:
+        data (dict): Input data.
+        obj_indices (list of int): Indices for the objects to select.
+    Returns:
+        data (dict): Output data.
+    """
+    op_keys = ['poses-openpose', 'captions-clip']
+    for op_key in op_keys:
+        if op_key in data:
+            for i in range(len(data[op_key])):
+                # data[op_key] is a list of dicts for different frames.
+                # people = data[op_key][i]['people']
+                people = data[op_key][i]
+                # "people" is a list of people dicts found by OpenPose. We will
+                # use the obj_index to get the target person from the list, and
+                # write it back to the dict.
+                # data[op_key][i]['people'] = [people[obj_indices[i]]]
+                if obj_indices is not None:
+                    data[op_key][i] = people[obj_indices[i]]
+                else:
+                    if op_key == 'poses-openpose':
+                        data[op_key][i] = people[0]
+                    else:
+                        idx = random.randint(0, len(people) - 1)
+                        data[op_key][i] = people[idx]
+    return data
+
+
+def concat_frames(prev, now, n_frames):
+    r"""Concat previous and current frames and only keep the latest $(n_frames).
+    If concatenated frames are longer than $(n_frames), drop the oldest one.
+
+    Args:
+        prev (NxTxCxHxW tensor): Tensor for previous frames.
+        now (NxCxHxW tensor): Tensor for current frame.
+        n_frames (int): Max number of frames to store.
+    Returns:
+        result (NxTxCxHxW tensor): Updated tensor.
+    """
+    now = now.unsqueeze(1)
+    if prev is None:
+        return now
+    if prev.shape[1] == n_frames:
+        prev = prev[:, 1:]
+    return torch.cat([prev, now], dim=1)
+
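+# A minimal sketch of the frame-buffer behavior (hypothetical shapes): with
+# n_frames=2, only the two most recent frames are kept.
+#   buf = None
+#   for _ in range(3):
+#       buf = concat_frames(buf, torch.randn(1, 3, 64, 64), n_frames=2)
+#   # buf.shape == (1, 2, 3, 64, 64)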
+
+def combine_fg_mask(fg_mask, ref_fg_mask, has_fg):
+    r"""Get the union of target and reference foreground masks.
+    Args:
+        fg_mask (tensor): Foreground mask for target image.
+        ref_fg_mask (tensor): Foreground mask for reference image.
+        has_fg (bool): Whether the image can be classified into fg/bg.
+    Returns:
+        output (tensor or int): Combined foreground mask.
+    """
+    return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1
+
+
+def get_fg_mask(densepose_map, has_fg):
+    r"""Obtain the foreground mask for pose sequences, which only includes
+    the human. This is done by looking at the body part map from DensePose.
+
+    Args:
+        densepose_map (NxCxHxW tensor): DensePose map.
+        has_fg (bool): Whether data has foreground or not.
+    Returns:
+        mask (Nx1xHxW tensor): fg mask.
+    """
+    if type(densepose_map) == list:
+        return [get_fg_mask(label, has_fg) for label in densepose_map]
+    if not has_fg or densepose_map is None:
+        return 1
+    if densepose_map.dim() == 5:
+        densepose_map = densepose_map[:, 0]
+    # Get the body part map from DensePose.
+    mask = densepose_map[:, 2:3]
+
+    # Make the mask slightly larger.
+    mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask)
+    mask = (mask > -1).float()
+    return mask
+
+
+def get_part_mask(densepose_map):
+    r"""Obtain mask of different body parts of humans. This is done by
+    looking at the body part map from DensePose.
+
+    Args:
+        densepose_map (NxCxHxW tensor): DensePose map.
+    Returns:
+        mask (NxKxHxW tensor): Body part mask, where K is the number of parts.
+    """
+    # Groups of body parts. Each group contains IDs of body part labels in
+    # DensePose. The 9 groups here are: background, torso, hands, feet,
+    # upper legs, lower legs, upper arms, lower arms, head.
+    part_groups = [[0], [1, 2], [3, 4], [5, 6], [7, 9, 8, 10], [11, 13, 12, 14],
+                   [15, 17, 16, 18], [19, 21, 20, 22], [23, 24]]
+    n_parts = len(part_groups)
+
+    need_reshape = densepose_map.dim() == 4
+    if need_reshape:
+        bo, t, h, w = densepose_map.size()
+        densepose_map = densepose_map.view(-1, h, w)
+    b, h, w = densepose_map.size()
+    part_map = (densepose_map / 2 + 0.5) * 24
+    assert (part_map >= 0).all() and (part_map < 25).all()
+
+    mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0)
+    for i in range(n_parts):
+        for j in part_groups[i]:
+            # Account for numerical errors.
+            mask[:, i] = mask[:, i] | (
+                (part_map > j - 0.1) & (part_map < j + 0.1)).byte()
+    if need_reshape:
+        mask = mask.view(bo, t, -1, h, w)
+    return mask.float()
+
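+# Sketch of the label convention assumed above (hypothetical values): a part id
+# p in [0, 24] is stored in the normalized map as p / 24 * 2 - 1 and recovered
+# by (x / 2 + 0.5) * 24, which is what the comparison against each group id checks.
+#   p = 23                      # one of the two head labels
+#   x = p / 24 * 2 - 1          # value as stored in densepose_map
+#   # (x / 2 + 0.5) * 24 == p (up to floating-point error)
+# Note that the mask is allocated with torch.cuda.ByteTensor, so this helper
+# assumes a CUDA device is available.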
+
+def get_face_mask(densepose_map):
+    r"""Obtain mask of faces.
+    Args:
+        densepose_map (3D or 4D tensor): DensePose map.
+    Returns:
+        mask (3D or 4D tensor): Face mask.
+    """
+    need_reshape = densepose_map.dim() == 4
+    if need_reshape:
+        bo, t, h, w = densepose_map.size()
+        densepose_map = densepose_map.view(-1, h, w)
+
+    b, h, w = densepose_map.size()
+    part_map = (densepose_map / 2 + 0.5) * 24
+    assert (part_map >= 0).all() and (part_map < 25).all()
+    if densepose_map.is_cuda:
+        mask = torch.cuda.ByteTensor(b, h, w).fill_(0)
+    else:
+        mask = torch.ByteTensor(b, h, w).fill_(0)
+    for j in [23, 24]:
+        mask = mask | ((part_map > j - 0.1) & (part_map < j + 0.1)).byte()
+    if need_reshape:
+        mask = mask.view(bo, t, h, w)
+    return mask.float()
+
+
+def extract_valid_pose_labels(pose_map, pose_type, remove_face_labels,
+                              do_remove=True):
+    r"""Remove some labels (e.g. face regions) in the pose map if necessary.
+
+    Args:
+        pose_map (3D, 4D or 5D tensor): Input pose map.
+        pose_type (str): 'both' or 'open'.
+        remove_face_labels (bool): Whether to remove labels for the face region.
+        do_remove (bool): Whether to actually perform the removal.
+    Returns:
+        pose_map (3D, 4D or 5D tensor): Output pose map.
+    """
+    if pose_map is None:
+        return pose_map
+    if type(pose_map) == list:
+        return [extract_valid_pose_labels(p, pose_type, remove_face_labels,
+                                          do_remove) for p in pose_map]
+
+    orig_dim = pose_map.dim()
+    assert (orig_dim >= 3 and orig_dim <= 5)
+    if orig_dim == 3:
+        pose_map = pose_map.unsqueeze(0).unsqueeze(0)
+    elif orig_dim == 4:
+        pose_map = pose_map.unsqueeze(0)
+
+    if pose_type == 'open':
+        # If input is only openpose, remove densepose part.
+        pose_map = pose_map[:, :, 3:]
+
+    elif remove_face_labels and do_remove:
+        # Remove face part for densepose input.
+        densepose, openpose = pose_map[:, :, :3], pose_map[:, :, 3:]
+        face_mask = get_face_mask(pose_map[:, :, 2]).unsqueeze(2)
+        pose_map = torch.cat([densepose * (1 - face_mask) - face_mask,
+                              openpose], dim=2)
+
+    if orig_dim == 3:
+        pose_map = pose_map[0, 0]
+    elif orig_dim == 4:
+        pose_map = pose_map[0]
+    return pose_map
+
+
+def normalize_faces(keypoints, ref_keypoints,
+                    dist_scale_x=None, dist_scale_y=None):
+    r"""Normalize face keypoints w.r.t. the reference face keypoints.
+
+    Args:
+        keypoints (Kx2 numpy array): target facial keypoints.
+        ref_keypoints (Kx2 numpy array): reference facial keypoints.
+    Returns:
+        keypoints (Kx2 numpy array): normalized facial keypoints.
+    """
+    if keypoints.shape[0] == 68:
+        central_keypoints = [8]
+        add_upper_face = False
+        part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12],
+                     [5, 11], [6, 10], [7, 9, 8],
+                     [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
+                     [27], [28], [29], [30], [31, 35], [32, 34], [33],
+                     [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+                     [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
+                     [57],
+                     [60, 64], [61, 63], [62], [65, 67], [66]
+                     ]
+        if add_upper_face:
+            part_list += [[68, 82], [69, 81], [70, 80], [71, 79], [72, 78],
+                          [73, 77], [74, 76, 75]]
+    elif keypoints.shape[0] == 126:
+        central_keypoints = [16]
+        part_list = [[i] for i in range(126)]
+    else:
+        raise ValueError('Input keypoints type not supported.')
+
+    face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
+    ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)
+
+    def get_mean_dists(pts, face_cen):
+        r"""Get the mean xy distances of keypoints wrt face center."""
+        mean_dists_x, mean_dists_y = [], []
+        pts_cen = np.mean(pts, axis=0)
+        for p, pt in enumerate(pts):
+            mean_dists_x.append(np.linalg.norm(pt - pts_cen))
+            mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
+        mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
+        mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
+        return mean_dist_x, mean_dist_y
+
+    if dist_scale_x is None:
+        dist_scale_x, dist_scale_y = [None] * len(part_list), \
+                                     [None] * len(part_list)
+
+    for i, pts_idx in enumerate(part_list):
+        pts = keypoints[pts_idx]
+        if dist_scale_x[i] is None:
+            ref_pts = ref_keypoints[pts_idx]
+            mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
+            ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
+
+            dist_scale_x[i] = ref_dist_x / mean_dist_x
+            dist_scale_y[i] = ref_dist_y / mean_dist_y
+
+        pts_cen = np.mean(pts, axis=0)
+        pts = (pts - pts_cen) * dist_scale_x[i] + \
+              (pts_cen - face_cen) * dist_scale_y[i] + face_cen
+        keypoints[pts_idx] = pts
+    return keypoints, [dist_scale_x, dist_scale_y]
+
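+# A minimal usage sketch (random arrays purely for shape illustration): the
+# target's 68 dlib keypoints are rescaled part-by-part to match the spread of
+# the reference face, and the returned scales can be reused for later frames.
+#   kp = np.random.rand(68, 2) * 256
+#   ref_kp = np.random.rand(68, 2) * 256
+#   kp_norm, (sx, sy) = normalize_faces(kp.copy(), ref_kp)
+#   # later frames: normalize_faces(next_kp, ref_kp, dist_scale_x=sx, dist_scale_y=sy)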
+
+def crop_face_from_output(data_cfg, image, input_label, crop_smaller=0):
+    r"""Crop out the face region of the image (and resize if necessary to feed
+    into generator/discriminator).
+
+    Args:
+        data_cfg (obj): Data configuration.
+        image (NxC1xHxW tensor or list of tensors): Image to crop.
+        input_label (NxC2xHxW tensor): Input label map.
+        crop_smaller (int): Number of pixels to crop slightly smaller region.
+    Returns:
+        output (NxC1xHxW tensor or list of tensors): Cropped image.
+    """
+    if type(image) == list:
+        return [crop_face_from_output(data_cfg, im, input_label, crop_smaller)
+                for im in image]
+
+    output = None
+    face_size = image.shape[-2] // 32 * 8
+    for i in range(input_label.size(0)):
+        ys, ye, xs, xe = get_face_bbox_for_output(data_cfg,
+                                                  input_label[i:i + 1],
+                                                  crop_smaller=crop_smaller)
+        output_i = F.interpolate(image[i:i + 1, -3:, ys:ye, xs:xe],
+                                 size=(face_size, face_size), mode='bilinear',
+                                 align_corners=True)
+        # output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
+        output = torch.cat([output, output_i]) if i != 0 else output_i
+    return output
+
+
+def get_face_bbox_for_output(data_cfg, pose, crop_smaller=0):
+    r"""Get pixel coordinates of the face bounding box.
+
+    Args:
+        data_cfg (obj): Data configuration.
+        pose (NxCxHxW tensor): Pose label map.
+        crop_smaller (int): Number of pixels to crop slightly smaller region.
+    Returns:
+        output (list of int): Face bbox.
+    """
+    if pose.dim() == 3:
+        pose = pose.unsqueeze(0)
+    elif pose.dim() == 5:
+        pose = pose[-1, -1:]
+    _, _, h, w = pose.size()
+
+    use_openpose = 'pose_maps-densepose' not in data_cfg.input_labels
+    if use_openpose:  # Use openpose face keypoints to identify face region.
+        for input_type in data_cfg.input_types:
+            if 'poses-openpose' in input_type:
+                num_ch = input_type['poses-openpose'].num_channels
+        if num_ch > 3:
+            face = (pose[:, -1] > 0).nonzero(as_tuple=False)
+        else:
+            raise ValueError('Not implemented yet.')
+    else:  # Use densepose labels.
+        face = (pose[:, 2] > 0.9).nonzero(as_tuple=False)
+
+    ylen = xlen = h // 32 * 8
+    if face.size(0):
+        y, x = face[:, 1], face[:, 2]
+        ys, ye = y.min().item(), y.max().item()
+        xs, xe = x.min().item(), x.max().item()
+        if use_openpose:
+            xc, yc = (xs + xe) // 2, (ys * 3 + ye * 2) // 5
+            ylen = int((xe - xs) * 2.5)
+        else:
+            xc, yc = (xs + xe) // 2, (ys + ye) // 2
+            ylen = int((ye - ys) * 1.25)
+        ylen = xlen = min(w, max(32, ylen))
+        yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
+        xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
+    else:
+        yc = h // 4
+        xc = w // 2
+
+    ys, ye = yc - ylen // 2, yc + ylen // 2
+    xs, xe = xc - xlen // 2, xc + xlen // 2
+    if crop_smaller != 0:  # Crop slightly smaller region inside face.
+        ys += crop_smaller
+        xs += crop_smaller
+        ye -= crop_smaller
+        xe -= crop_smaller
+    return [ys, ye, xs, xe]
+
+
+def crop_hand_from_output(data_cfg, image, input_label):
+    r"""Crop out the hand region of the image.
+
+    Args:
+        data_cfg (obj): Data configuration.
+        image (NxC1xHxW tensor or list of tensors): Image to crop.
+        input_label (NxC2xHxW tensor): Input label map.
+    Returns:
+        output (NxC1xHxW tensor or list of tensors): Cropped image.
+    """
+    if type(image) == list:
+        return [crop_hand_from_output(data_cfg, im, input_label)
+                for im in image]
+
+    output = None
+    for i in range(input_label.size(0)):
+        coords = get_hand_bbox_for_output(data_cfg, input_label[i:i + 1])
+        if coords:
+            for coord in coords:
+                ys, ye, xs, xe = coord
+                output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
+                output = torch.cat([output, output_i]) \
+                    if output is not None else output_i
+    return output
+
+
+def get_hand_bbox_for_output(data_cfg, pose):
+    r"""Get coordinates of the hand bounding box.
+
+    Args:
+        data_cfg (obj): Data configuration.
+        pose (NxCxHxW tensor): Pose label map.
+    Returns:
+        output (list of int): Hand bbox.
+    """
+    if pose.dim() == 3:
+        pose = pose.unsqueeze(0)
+    elif pose.dim() == 5:
+        pose = pose[-1, -1:]
+    _, _, h, w = pose.size()
+    ylen = xlen = h // 64 * 8
+
+    coords = []
+    colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]]
+    for i, color in enumerate(colors):
+        if pose.shape[1] > 6:  # Using one-hot encoding for openpose.
+            idx = -3 if i == 0 else -2
+            hand = (pose[:, idx] == 1).nonzero(as_tuple=False)
+        else:
+            raise ValueError('Not implemented yet.')
+        if hand.size(0):
+            y, x = hand[:, 1], hand[:, 2]
+            ys, ye, xs, xe = y.min().item(), y.max().item(), \
+                x.min().item(), x.max().item()
+            xc, yc = (xs + xe) // 2, (ys + ye) // 2
+            yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
+            xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
+            ys, ye, xs, xe = yc - ylen // 2, yc + ylen // 2, \
+                xc - xlen // 2, xc + xlen // 2
+            coords.append([ys, ye, xs, xe])
+    return coords
+
+
+def pre_process_densepose(pose_cfg, pose_map, is_infer=False):
+    r"""Pre-process the DensePose part of input label map.
+
+    Args:
+        pose_cfg (obj): Pose data configuration.
+        pose_map (NxCxHxW tensor): Pose label map.
+        is_infer (bool): Is doing inference.
+    Returns:
+        pose_map (NxCxHxW tensor): Processed pose label map.
+    """
+    part_map = pose_map[:, :, 2] * 255  # should be within [0-24]
+    assert (part_map >= 0).all() and (part_map < 25).all()
+
+    # Randomly drop some body part during training.
+    if not is_infer:
+        random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0)
+    else:
+        random_drop_prob = 0
+    if random_drop_prob > 0:
+        densepose_map = pose_map[:, :, :3]
+        for part_id in range(1, 25):
+            if (random.random() < random_drop_prob):
+                part_mask = abs(part_map - part_id) < 0.1
+                densepose_map[part_mask.unsqueeze(2).expand_as(
+                    densepose_map)] = 0
+        pose_map[:, :, :3] = densepose_map
+
+    # Renormalize the DensePose part channel from [0, 24/255] to [0, 1].
+    pose_map[:, :, 2] = pose_map[:, :, 2] * (255 / 24)
+    # Normalize from [0, 1] to [-1, 1].
+    pose_map = pose_map * 2 - 1
+    return pose_map
+
+
+def random_roll(tensors):
+    r"""Randomly roll the input tensors along x and y dimensions. Also randomly
+    flip the tensors.
+
+    Args:
+        tensors (list of 4D tensors): Input tensors.
+    Returns:
+        output (list of 4D tensors): Rolled tensors.
+    """
+    h, w = tensors[0].shape[2:]
+    ny = np.random.choice([np.random.randint(h//16),
+                           h-np.random.randint(h//16)])
+    nx = np.random.choice([np.random.randint(w//16),
+                           w-np.random.randint(w//16)])
+    flip = np.random.rand() > 0.5
+    return [roll(t, ny, nx, flip) for t in tensors]
+
+
+def roll(t, ny, nx, flip=False):
+    r"""Roll and flip the tensor by specified amounts.
+
+    Args:
+        t (4D tensor): Input tensor.
+        ny (int): Amount to roll along y dimension.
+        nx (int): Amount to roll along x dimension.
+        flip (bool): Whether to flip input.
+    Returns:
+        t (4D tensor): Output tensor.
+    """
+    t = torch.cat([t[:, :, -ny:], t[:, :, :-ny]], dim=2)
+    t = torch.cat([t[:, :, :, -nx:], t[:, :, :, :-nx]], dim=3)
+    if flip:
+        t = torch.flip(t, dims=[3])
+    return t
+
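+# A minimal sketch (hypothetical 1x1x4x4 tensor): roll(t, 1, 2) is equivalent to
+# torch.roll(t, shifts=(1, 2), dims=(2, 3)), optionally followed by a horizontal flip.
+#   t = torch.arange(16.).view(1, 1, 4, 4)
+#   # torch.equal(roll(t, 1, 2), torch.roll(t, shifts=(1, 2), dims=(2, 3))) is True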
+
+def detach(output):
+    r"""Detach tensors in the dict.
+
+    Args:
+        output (dict): Output dict.
+    Returns:
+        output (dict): Detached output dict.
+    """
+    if type(output) == dict:
+        new_dict = dict()
+        for k, v in output.items():
+            new_dict[k] = detach(v)
+        return new_dict
+    elif type(output) == torch.Tensor:
+        return output.detach()
+    return output
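+
+# A minimal sketch (hypothetical output dict): detach() walks nested dicts and
+# detaches any torch.Tensor leaves, leaving other values untouched.
+#   out = {'image': torch.randn(1, 3, 8, 8, requires_grad=True), 'step': 3}
+#   out = detach(out)
+#   # out['image'].requires_grad is False, out['step'] is still 3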
diff --git a/imaginaire/model_utils/gancraft/camctl.py b/imaginaire/model_utils/gancraft/camctl.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e4ab674b7f6b44d484d06c661267ed7ce69d56
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/camctl.py
@@ -0,0 +1,640 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+
+
+class EvalCameraController:
+    def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0):
+        self.voxel = voxel
+        self.maxstep = maxstep
+        self.camera_poses = []  # ori, dir, up, f
+        circle = torch.linspace(0, 2*np.pi, steps=maxstep)
+        size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+        # Shrink the circle a bit.
+        shift = size * 0.2
+        size = size * 0.8
+
+        if pattern == 0:
+            height_history = []
+            # Calculate smooth height.
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    70,
+                    torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+                height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+            # Filtfilt
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    70,
+                    torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = height_history[i]
+
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2))  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        elif pattern == 1:
+            zoom = torch.linspace(1.0, 0.25, steps=maxstep)
+            height_history = []
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+                height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = height_history[i]
+
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i])  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        elif pattern == 2:
+            move = torch.linspace(1.0, 0.2, steps=maxstep)
+            height_history = []
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = height_history[i]
+
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2))  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        elif pattern == 3:
+            move = torch.linspace(0.75, 0.2, steps=maxstep)
+            height_history = []
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    70,
+                    torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    70,
+                    torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = height_history[i]
+
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2))  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        elif pattern == 4:
+            move = torch.linspace(1.0, 0.5, steps=maxstep)
+            height_history = []
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    90,
+                    torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = height_history[i]
+
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2))  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        # look outward
+        elif pattern == 5:
+            move = torch.linspace(1.0, 0.5, steps=maxstep)
+            height_history = []
+            for i in range(maxstep):
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0]))
+
+            height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+            for i in range(maxstep):
+                nearpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                nearpoint[0] = height_history[i]
+
+                farpoint = torch.tensor([
+                    60,
+                    torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+                cam_ori = self.voxel.world2local(nearpoint)
+                cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2))  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+        # Rise
+        elif pattern == 6:
+            shift = 0
+            lift = torch.linspace(0.0, 200.0, steps=maxstep)
+            zoom = torch.linspace(0.8, 1.6, steps=maxstep)
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    80+lift[i],
+                    torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
+
+                farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+                nearpoint = torch.tensor([
+                    65,
+                    torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
+                    torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i])  # about 24mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+        # 45deg
+        elif pattern == 7:
+            rad = torch.tensor([np.deg2rad(45).astype(np.float32)])
+            size = 1536
+            for i in range(maxstep):
+                farpoint = torch.tensor([
+                    61+size,
+                    torch.sin(rad)*size + voxel.voxel_t.size(1)/2,
+                    torch.cos(rad)*size + voxel.voxel_t.size(2)/2])
+
+                nearpoint = torch.tensor([
+                    61,
+                    voxel.voxel_t.size(1)/2,
+                    voxel.voxel_t.size(2)/2])
+                cam_ori = self.voxel.world2local(farpoint)
+                cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+                cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+                cam_f = 0.5/np.tan(np.deg2rad(19.5/2))  # about 50mm fov
+
+                self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+    def _get_height(self, loc0, loc1, minheight):
+        loc0 = int(loc0)
+        loc1 = int(loc1)
+        height = minheight
+        for dx in range(-3, 4):
+            for dy in range(-3, 4):
+                if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
+                        (loc1+dy) >= self.voxel.heightmap.shape[1]:
+                    height = max(height, minheight)
+                else:
+                    height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
+        return height
+
+    def filtfilt(self, height_history, decay=0.2):
+        # Filtfilt
+        height_history2 = []
+        maxstep = len(height_history)
+        prev_height = height_history[0]
+        for i in range(maxstep):
+            prev_height = prev_height - decay
+            if prev_height < height_history[i]:
+                prev_height = height_history[i]
+            height_history2.append(prev_height)
+        prev_height = height_history[-1]
+        for i in range(maxstep-1, -1, -1):
+            prev_height = prev_height - decay
+            if prev_height < height_history[i]:
+                prev_height = height_history[i]
+            height_history2[i] = max(prev_height, height_history2[i])
+        return height_history2
+
+    def __len__(self):
+        return len(self.camera_poses)
+
+    def __getitem__(self, idx):
+        return self.camera_poses[idx]
+
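+# Sketch of the voxel interface these controllers rely on (inferred from the
+# usage above, not a definitive spec): `voxel.voxel_t` provides the map extent
+# via size(1)/size(2), `voxel.heightmap` is an indexable 2D height field, and
+# `voxel.world2local(p, is_vec=False)` maps world points or vectors into the
+# local voxel frame.
+#   ctl = EvalCameraController(voxel, maxstep=128, pattern=0)
+#   cam_ori, cam_dir, cam_up, cam_f = ctl[0]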
+
+class TourCameraController:
+    def __init__(self, voxel, maxstep=128):
+        self.voxel = voxel
+        self.maxstep = maxstep
+        self.camera_poses = []  # ori, dir, up, f
+        circle = torch.linspace(0, 2*np.pi, steps=maxstep//4)
+        size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+        # Shrink the circle a bit
+        shift = size * 0.2
+        size = size * 0.8
+
+        for i in range(maxstep//4):
+            farpoint = torch.tensor([
+                70,
+                torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+            farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+            nearpoint = torch.tensor([
+                60,
+                torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
+            cam_ori = self.voxel.world2local(farpoint)
+            cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+            cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+            cam_f = 0.5/np.tan(np.deg2rad(73/2))  # about 24mm fov
+
+            self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        zoom = torch.linspace(1.0, 0.25, steps=maxstep//4)
+        for i in range(maxstep//4):
+            farpoint = torch.tensor([
+                90,
+                torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+            farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+            nearpoint = torch.tensor([
+                60,
+                torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
+            cam_ori = self.voxel.world2local(farpoint)
+            cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+            cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+            cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i])  # about 24mm fov
+
+            self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        move = torch.linspace(1.0, 0.2, steps=maxstep//4)
+        for i in range(maxstep//4):
+            farpoint = torch.tensor([
+                90,
+                torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+            farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+            nearpoint = torch.tensor([
+                60,
+                torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+            cam_ori = self.voxel.world2local(farpoint)
+            cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+            cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+            cam_f = 0.5/np.tan(np.deg2rad(73/2))  # about 24mm fov
+
+            self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+        lift = torch.linspace(0.0, 200.0, steps=maxstep//4)
+        zoom = torch.linspace(0.6, 1.2, steps=maxstep//4)
+        for i in range(maxstep//4):
+            farpoint = torch.tensor([
+                80+lift[i],
+                torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
+
+            farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+            nearpoint = torch.tensor([
+                60,
+                torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
+                torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
+            cam_ori = self.voxel.world2local(farpoint)
+            cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+            cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+            cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i])  # about 24mm fov
+
+            self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+    def _get_height(self, loc0, loc1, minheight):
+        loc0 = int(loc0)
+        loc1 = int(loc1)
+        height = minheight
+        for dx in range(-3, 4):
+            for dy in range(-3, 4):
+                if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
+                        (loc1+dy) >= self.voxel.heightmap.shape[1]:
+                    height = max(height, minheight)
+                else:
+                    height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
+        return height
+
+    def __len__(self):
+        return len(self.camera_poses)
+
+    def __getitem__(self, idx):
+        return self.camera_poses[idx]
+
+
+def rand_camera_pose_birdseye(voxel, border=128):
+    r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up
+    Assuming [Y X Z] coordinate. Y is negative gravity direction.
+    The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering
+    1. Uniformly sample a point on the upper hemisphere of a unit sphere, as cam_ori.
+    2. Set cam_dir to be from cam_ori to the origin
+    3. cam_up is always pointing towards sky
+    4. move cam_ori to random place according to voxel size
+    """
+    cam_dir = torch.randn(3, dtype=torch.float32)
+    cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir))
+    cam_dir[0] = -torch.abs(cam_dir[0])
+    cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
+
+    # generate camera lookat target
+    r = np.random.rand(2)
+    r[0] *= voxel.voxel_t.size(1)-border-border
+    r[1] *= voxel.voxel_t.size(2)-border-border
+    r = r + border
+    y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5
+    cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
+    cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100)
+    cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0])
+    # Translate to voxel coordinate
+    cam_ori = voxel.world2local(cam_ori)
+    cam_dir = voxel.world2local(cam_dir, is_vec=True)
+    cam_up = voxel.world2local(cam_up, is_vec=True)
+
+    return cam_ori, cam_dir, cam_up
+
+
+def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7):
+    loc0 = int(loc0)
+    loc1 = int(loc1)
+    height = 0
+    for dx in range(-neighbor_size//2, neighbor_size//2+1):
+        for dy in range(-neighbor_size//2, neighbor_size//2+1):
+            if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]:
+                height = max(height, minheight)
+            else:
+                height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2)
+    return height
+
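+# A minimal sketch (hypothetical flat heightmap): the returned height is at
+# least minheight and sits roughly 2 voxels above the nearby terrain, falling
+# back to minheight near the map border.
+#   heightmap = np.full((128, 128), 64)
+#   h = get_neighbor_height(heightmap, 32, 32, minheight=60, neighbor_size=7)
+#   # h == 66 (64 + 2 clearance)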
+
+def rand_camera_pose_firstperson(voxel, border=128):
+    r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up
+    """
+    r = np.random.rand(5)
+    r[0] *= voxel.voxel_t.size(1)-border-border
+    r[1] *= voxel.voxel_t.size(2)-border-border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+
+    y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15
+
+    cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
+
+    rand_ang_h = r[2] * 2 * np.pi
+    cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] +
+                              np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32)
+    cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1],
+                                        cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10
+
+    cam_dir = cam_target - cam_ori
+
+    cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
+
+    cam_ori = voxel.world2local(cam_ori)
+    cam_dir = voxel.world2local(cam_dir, is_vec=True)
+    cam_up = voxel.world2local(cam_up, is_vec=True)
+
+    return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson(voxel, border=96):
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1)
+    r[1] *= voxel.voxel_t.size(2)
+    rand_height = 60 + torch.rand(1) * 40
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
+    farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1) - border - border
+    r[1] *= voxel.voxel_t.size(2) - border - border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
+    nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    cam_ori = voxel.world2local(farpoint)
+    cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+    cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+
+    return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson2(voxel, border=48):
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1) - border - border
+    r[1] *= voxel.voxel_t.size(2) - border - border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+    rand_height = 60 + torch.rand(1) * 40
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
+    farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1) - border - border
+    r[1] *= voxel.voxel_t.size(2) - border - border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
+    nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    # Random Up vector (tilt a little bit)
+    # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+    up = torch.randn(3) * 0.02
+    up[0] = 1.0
+    up = up / up.norm(p=2)
+    cam_ori = voxel.world2local(farpoint)
+    cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+    cam_up = voxel.world2local(up, is_vec=True)
+
+    return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson3(voxel, border=64):
+    r"""Attempting to solve the camera too close to wall problem and the lack of aerial poses."""
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1) - border - border
+    r[1] *= voxel.voxel_t.size(2) - border - border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+    rand_height = 60 + torch.rand(1) * 40
+    if torch.rand(1) > 0.8:
+        rand_height = 60 + torch.rand(1) * 60
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7)
+    farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    r = torch.rand(2)
+    r[0] *= voxel.voxel_t.size(1) - border - border
+    r[1] *= voxel.voxel_t.size(2) - border - border
+    r[0] = r[0] + border
+    r[1] = r[1] + border
+    rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5
+    nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+    # Random Up vector (tilt a little bit)
+    # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+    up = torch.randn(3) * 0.02
+    up[0] = 1.0
+    up = up / up.norm(p=2)
+    # print(up)
+    cam_ori = voxel.world2local(farpoint)
+    cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+    cam_up = voxel.world2local(up, is_vec=True)
+
+    return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_tour(voxel):
+    size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+    center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
+
+    rnd = torch.rand(8)
+
+    rnd_deg = torch.rand(1) * 2 * np.pi
+    far_radius = rnd[0]*0.8+0.2
+    far_height = rnd[1]*30 + 60
+    farpoint = torch.tensor([
+        far_height,
+        torch.sin(rnd_deg)*size*far_radius + center[0],
+        torch.cos(rnd_deg)*size*far_radius + center[1]])
+
+    farpoint[0] = get_neighbor_height(voxel.heightmap, farpoint[1], farpoint[2], farpoint[0], neighbor_size=7)
+
+    near_radius = far_radius * rnd[2]
+    near_shift_rad = np.pi*(rnd[3]-0.5)
+    near_height = 60 + rnd[4] * 10
+    nearpoint = torch.tensor([
+        near_height,
+        torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
+        torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
+
+    # Random Up vector (tilt a little bit)
+    # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+    up = torch.randn(3) * 0.02
+    up[0] = 1.0
+    up = up / up.norm(p=2)
+    cam_ori = voxel.world2local(farpoint)
+    cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+    cam_up = voxel.world2local(up, is_vec=True)
+    cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25))  # about 24mm fov
+
+    return cam_ori, cam_dir, cam_up, cam_f
+
+# Look from center to outward
+
+
+def rand_camera_pose_insideout(voxel):
+    size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+    center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
+
+    rnd = torch.rand(8)
+
+    rnd_deg = torch.rand(1) * 2 * np.pi
+    far_radius = rnd[0]*0.8+0.2
+    far_height = rnd[1]*10 + 60
+    farpoint = torch.tensor([
+        far_height,
+        torch.sin(rnd_deg)*size*far_radius + center[0],
+        torch.cos(rnd_deg)*size*far_radius + center[1]])
+
+    near_radius = far_radius * rnd[2]
+    near_shift_rad = np.pi*(rnd[3]-0.5)
+    near_height = 60 + rnd[4] * 30
+    nearpoint = torch.tensor([
+        near_height,
+        torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
+        torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
+
+    nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7)
+
+    # Random Up vector (tilt a little bit)
+    # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+    up = torch.randn(3) * 0.02
+    up[0] = 1.0
+    up = up / up.norm(p=2)
+    cam_ori = voxel.world2local(nearpoint)
+    cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True)
+    cam_up = voxel.world2local(up, is_vec=True)
+    cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25))  # about 24mm fov
+
+    return cam_ori, cam_dir, cam_up, cam_f
diff --git a/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ba061b7a5bda98899c8b1f653d5f204fccfa38e4
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv
@@ -0,0 +1,182 @@
+person,#00AC0D
+bicycle,#012F47
+car,#0275B8
+motorcycle,#03C098
+airplane,#04434F
+bus,#05FB29
+train,#06C312
+truck,#076728
+boat,#0809B6
+traffic-light,#09D3CF
+fire-hydrant,#0A150B
+street-sign,#0BF2A6
+stop-sign,#0C246F
+parking-meter,#0D575D
+bench,#0E46F9
+bird,#0FD881
+cat,#1058DF
+dog,#118C76
+horse,#123A2C
+sheep,#13C1D8
+cow,#14E67D
+elephant,#152718
+bear,#165743
+zebra,#17AED2
+giraffe,#1858EF
+hat,#195103
+backpack,#1AA5EA
+umbrella,#1B19CC
+shoe,#1C4DE6
+eye-glasses,#1D4823
+handbag,#1E09D6
+tie,#1F94FE
+suitcase,#2073BD
+frisbee,#21D0C5
+skis,#22F3D7
+snowboard,#23C52B
+sports-ball,#24FE20
+kite,#254F0B
+baseball-bat,#26AF68
+baseball-glove,#27C0D4
+skateboard,#28528A
+surfboard,#2963B6
+tennis-racket,#2AD8EB
+bottle,#2BB1A5
+plate,#2CF37D
+wine-glass,#2D1D9C
+cup,#2E936F
+fork,#2F93E8
+knife,#308E02
+spoon,#31A71B
+bowl,#3220D3
+banana,#33C1D9
+apple,#340997
+sandwich,#35B935
+orange,#367F33
+broccoli,#3720AE
+carrot,#381F94
+hot-dog,#39CAB5
+pizza,#3AF41D
+donut,#3B9743
+cake,#3CA323
+chair,#3DFE27
+couch,#3ECB89
+potted-plant,#3F7249
+bed,#40B729
+mirror,#411C97
+dining-table,#422283
+window,#43802E
+desk,#4480DA
+toilet,#45A4B2
+door,#46356C
+tv,#478503
+laptop,#48261F
+mouse,#49E809
+remote,#4AF48A
+keyboard,#4B111B
+cell-phone,#4C4FAD
+microwave,#4D84C7
+oven,#4E69A7
+toaster,#4F2A3D
+sink,#50BA55
+refrigerator,#511F61
+blender,#52782C
+book,#530122
+clock,#5441A2
+vase,#55E758
+scissors,#56A921
+teddy-bear,#573985
+hair-drier,#5823E8
+toothbrush,#5966FF
+hair-brush,#5A7724
+banner,#5B0B00
+blanket,#5CAECB
+branch,#5D5222
+bridge,#5E5BC5
+building-other,#5F807E
+bush,#606E32
+cabinet,#6163FE
+cage,#623550
+cardboard,#638CBE
+carpet,#647988
+ceiling-other,#65AABD
+ceiling-tile,#665481
+cloth,#67CBD1
+clothes,#684470
+clouds,#696969
+counter,#6AC478
+cupboard,#6B2F5B
+curtain,#6C7FA8
+desk-stuff,#6DF474
+dirt,#6E6E28
+door-stuff,#6FCCB0
+fence,#706419
+floor-marble,#71B443
+floor-other,#72E867
+floor-stone,#734EFC
+floor-tile,#748F23
+floor-wood,#759472
+flower,#760000
+fog,#77BA1D
+food-other,#7817F1
+fruit,#79CF21
+furniture-other,#7A8D92
+grass,#7BC800
+gravel,#7C32C8
+ground-other,#7D3054
+hill,#7EC864
+house,#7F4502
+leaves,#80A945
+light,#81A365
+mat,#82C08C
+metal,#835F2C
+mirror-stuff,#84C575
+moss,#855EFD
+mountain,#869664
+mud,#87716F
+napkin,#88B25B
+net,#892455
+paper,#8AA2A7
+pavement,#8B3027
+pillow,#8C5DCB
+plant,#8DE61E
+plastic,#8E629E
+platform,#8F2A91
+playingfield,#90CDC6
+railing,#9170C7
+railroad,#92E712
+river,#9364C8
+road,#946E28
+rock,#956432
+roof,#9600B1
+rug,#978A29
+salad,#98725D
+sand,#999900
+sea,#9AC6DA
+shelf,#9B7FC9
+sky,#9CEEDD
+skyscraper,#9DBBF2
+snow,#9E9EAA
+solid-other,#9F79DB
+stairs,#A06249
+stone,#A1A164
+straw,#A2A3EB
+structural,#A3DED1
+table,#A47B69
+tent,#A5C3BA
+textile-other,#A65280
+towel,#A7AED6
+tree,#A8C832
+vegetable,#A99410
+wall-brick,#AAD16A
+wall-concrete,#AB32A4
+wall-other,#AC9B5E
+wall-panel,#AD0E18
+wall-stone,#AE2974
+wall-tile,#AF3ABF
+wall-wood,#B0C1C3
+water,#B1C8FF
+waterdrops,#B20A88
+window-blind,#B356B8
+window-other,#B42B5B
+wood,#B57B00
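
`gaugan_lbl2col.csv` has no header row; each line pairs a GauGAN/COCO-Stuff label name with the hex color used for that class in segmentation maps. A small parsing sketch (illustrative only; the helper name is not from this repository):

```python
# Sketch: read gaugan_lbl2col.csv into a label -> (R, G, B) lookup.
import csv

def load_label_colors(path='imaginaire/model_utils/gancraft/gaugan_lbl2col.csv'):
    lbl2rgb = {}
    with open(path, newline='') as f:
        for name, hexcol in csv.reader(f):
            h = hexcol.lstrip('#')
            lbl2rgb[name] = tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))
    return lbl2rgb

# e.g. load_label_colors()['grass'] -> (123, 200, 0)
```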
diff --git a/imaginaire/model_utils/gancraft/gaugan_reduction.csv b/imaginaire/model_utils/gancraft/gaugan_reduction.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a49ad38c2f8bddd04a12a08a09b1bbdab5944fc2
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/gaugan_reduction.csv
@@ -0,0 +1,182 @@
+person,ignore
+bicycle,ignore
+car,ignore
+motorcycle,ignore
+airplane,ignore
+bus,ignore
+train,ignore
+truck,ignore
+boat,ignore
+traffic-light,ignore
+fire-hydrant,ignore
+street-sign,ignore
+stop-sign,ignore
+parking-meter,ignore
+bench,ignore
+bird,ignore
+cat,ignore
+dog,ignore
+horse,ignore
+sheep,ignore
+cow,ignore
+elephant,ignore
+bear,ignore
+zebra,ignore
+giraffe,ignore
+hat,ignore
+backpack,ignore
+umbrella,ignore
+shoe,ignore
+eye-glasses,ignore
+handbag,ignore
+tie,ignore
+suitcase,ignore
+frisbee,ignore
+skis,ignore
+snowboard,ignore
+sports-ball,ignore
+kite,ignore
+baseball-bat,ignore
+baseball-glove,ignore
+skateboard,ignore
+surfboard,ignore
+tennis-racket,ignore
+bottle,ignore
+plate,ignore
+wine-glass,ignore
+cup,ignore
+fork,ignore
+knife,ignore
+spoon,ignore
+bowl,ignore
+banana,ignore
+apple,ignore
+sandwich,ignore
+orange,ignore
+broccoli,ignore
+carrot,ignore
+hot-dog,ignore
+pizza,ignore
+donut,ignore
+cake,ignore
+chair,ignore
+couch,ignore
+potted-plant,ignore
+bed,ignore
+mirror,ignore
+dining-table,ignore
+window,ignore
+desk,ignore
+toilet,ignore
+door,ignore
+tv,ignore
+laptop,ignore
+mouse,ignore
+remote,ignore
+keyboard,ignore
+cell-phone,ignore
+microwave,ignore
+oven,ignore
+toaster,ignore
+sink,ignore
+refrigerator,ignore
+blender,ignore
+book,ignore
+clock,ignore
+vase,ignore
+scissors,ignore
+teddy-bear,ignore
+hair-drier,ignore
+toothbrush,ignore
+hair-brush,ignore
+banner,ignore
+blanket,ignore
+branch,tree
+bridge,ignore
+building-other,ignore
+bush,tree
+cabinet,ignore
+cage,ignore
+cardboard,ignore
+carpet,ignore
+ceiling-other,ignore
+ceiling-tile,ignore
+cloth,ignore
+clothes,ignore
+clouds,sky
+counter,ignore
+cupboard,ignore
+curtain,ignore
+desk-stuff,ignore
+dirt,dirt
+door-stuff,ignore
+fence,ignore
+floor-marble,ignore
+floor-other,ignore
+floor-stone,ignore
+floor-tile,ignore
+floor-wood,ignore
+flower,flower
+fog,sky
+food-other,ignore
+fruit,ignore
+furniture-other,ignore
+grass,grass
+gravel,gravel
+ground-other,ignore
+hill,grass
+house,ignore
+leaves,tree
+light,ignore
+mat,ignore
+metal,ignore
+mirror-stuff,ignore
+moss,grass
+mountain,grass
+mud,dirt
+napkin,ignore
+net,ignore
+paper,ignore
+pavement,ignore
+pillow,ignore
+plant,flower
+plastic,ignore
+platform,ignore
+playingfield,ignore
+railing,ignore
+railroad,ignore
+river,water
+road,ignore
+rock,rock
+roof,ignore
+rug,ignore
+salad,ignore
+sand,sand
+sea,water
+shelf,ignore
+sky,sky
+skyscraper,ignore
+snow,snow
+solid-other,ignore
+stairs,ignore
+stone,stone
+straw,grass
+structural,ignore
+table,ignore
+tent,ignore
+textile-other,ignore
+towel,ignore
+tree,tree
+vegetable,ignore
+wall-brick,ignore
+wall-concrete,ignore
+wall-other,ignore
+wall-panel,ignore
+wall-stone,ignore
+wall-tile,ignore
+wall-wood,ignore
+water,water
+waterdrops,ignore
+window-blind,ignore
+window-other,ignore
+wood,ignore
diff --git a/imaginaire/model_utils/gancraft/id2name_gg.csv b/imaginaire/model_utils/gancraft/id2name_gg.csv
new file mode 100644
index 0000000000000000000000000000000000000000..bb52afe4132cdae36494c08dab6ac4982f572386
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/id2name_gg.csv
@@ -0,0 +1,680 @@
+0,air,0,sky
+1,stone,7368816,stone
+2,granite,7368816,rock
+3,polished_granite,7368816,rock
+4,diorite,7368816,rock
+5,polished_diorite,7368816,rock
+6,andesite,7368816,rock
+7,polished_andesite,7368816,rock
+8,grass_block,8368696,grass
+9,dirt,9923917,dirt
+10,coarse_dirt,9923917,dirt
+11,podzol,9923917,dirt
+12,cobblestone,7368816,stone
+13,oak_planks,9402184,wood
+14,spruce_planks,9402184,wood
+15,birch_planks,9402184,wood
+16,jungle_planks,9402184,wood
+17,acacia_planks,9402184,wood
+18,dark_oak_planks,9402184,wood
+19,oak_sapling,31744,plant
+20,spruce_sapling,31744,plant
+21,birch_sapling,31744,plant
+22,jungle_sapling,31744,plant
+23,acacia_sapling,31744,plant
+24,dark_oak_sapling,31744,plant
+25,bedrock,7368816,rock
+26,water,4210943,water
+27,lava,16711680,
+28,sand,16247203,sand
+29,red_sand,16247203,sand
+30,gravel,16247203,gravel
+31,gold_ore,7368816,rock
+32,iron_ore,7368816,rock
+33,coal_ore,7368816,rock
+34,oak_log,9402184,tree
+35,spruce_log,9402184,tree
+36,birch_log,9402184,tree
+37,jungle_log,9402184,tree
+38,acacia_log,9402184,tree
+39,dark_oak_log,9402184,tree
+40,stripped_spruce_log,9402184,wood
+41,stripped_birch_log,9402184,wood
+42,stripped_jungle_log,9402184,wood
+43,stripped_acacia_log,9402184,wood
+44,stripped_dark_oak_log,9402184,wood
+45,stripped_oak_log,9402184,wood
+46,oak_wood,9402184,wood
+47,spruce_wood,9402184,wood
+48,birch_wood,9402184,wood
+49,jungle_wood,9402184,wood
+50,acacia_wood,9402184,wood
+51,dark_oak_wood,9402184,wood
+52,stripped_oak_wood,9402184,wood
+53,stripped_spruce_wood,9402184,wood
+54,stripped_birch_wood,9402184,wood
+55,stripped_jungle_wood,9402184,wood
+56,stripped_acacia_wood,9402184,wood
+57,stripped_dark_oak_wood,9402184,wood
+58,oak_leaves,31744,tree
+59,spruce_leaves,31744,tree
+60,birch_leaves,31744,tree
+61,jungle_leaves,31744,tree
+62,acacia_leaves,31744,tree
+63,dark_oak_leaves,31744,tree
+64,sponge,15066419,
+65,wet_sponge,15066419,
+66,glass,0,
+67,lapis_ore,7368816,
+68,lapis_block,10987431,
+69,dispenser,7368816,
+70,sandstone,7368816,sand
+71,chiseled_sandstone,7368816,sand
+72,cut_sandstone,7368816,sand
+73,note_block,9402184,
+74,white_bed,13092807,
+75,orange_bed,13092807,
+76,magenta_bed,13092807,
+77,light_blue_bed,13092807,
+78,yellow_bed,13092807,
+79,lime_bed,13092807,
+80,pink_bed,13092807,
+81,gray_bed,13092807,
+82,light_gray_bed,13092807,
+83,cyan_bed,13092807,
+84,purple_bed,13092807,
+85,blue_bed,13092807,
+86,brown_bed,13092807,
+87,green_bed,13092807,
+88,red_bed,13092807,
+89,black_bed,13092807,
+90,powered_rail,0,
+91,detector_rail,0,
+92,sticky_piston,7368816,
+93,cobweb,13092807,
+94,grass,31744,grass
+95,fern,31744,grass
+96,dead_bush,31744,grass
+97,seagrass,4210943,water
+98,tall_seagrass,4210943,water
+99,piston,7368816,
+100,piston_head,7368816,
+101,white_wool,13092807,
+102,orange_wool,13092807,
+103,magenta_wool,13092807,
+104,light_blue_wool,13092807,
+105,yellow_wool,13092807,
+106,lime_wool,13092807,
+107,pink_wool,13092807,
+108,gray_wool,13092807,
+109,light_gray_wool,13092807,
+110,cyan_wool,13092807,
+111,purple_wool,13092807,
+112,blue_wool,13092807,
+113,brown_wool,13092807,
+114,green_wool,13092807,
+115,red_wool,13092807,
+116,black_wool,13092807,
+117,moving_piston,7368816,
+118,dandelion,31744,flower
+119,poppy,31744,flower
+120,blue_orchid,31744,flower
+121,allium,31744,flower
+122,azure_bluet,31744,flower
+123,red_tulip,31744,flower
+124,orange_tulip,31744,flower
+125,white_tulip,31744,flower
+126,pink_tulip,31744,flower
+127,oxeye_daisy,31744,flower
+128,cornflower,31744,flower
+129,wither_rose,31744,flower
+130,lily_of_the_valley,31744,flower
+131,brown_mushroom,31744,flower
+132,red_mushroom,31744,flower
+133,gold_block,10987431,
+134,iron_block,10987431,
+135,bricks,7368816,
+136,tnt,16711680,
+137,bookshelf,9402184,
+138,mossy_cobblestone,7368816,
+139,obsidian,7368816,
+140,torch,0,
+141,wall_torch,0,
+142,fire,0,
+143,spawner,7368816,
+144,oak_stairs,9402184,
+145,chest,9402184,
+146,redstone_wire,0,
+147,diamond_ore,7368816,
+148,diamond_block,10987431,
+149,crafting_table,9402184,
+150,wheat,31744,
+151,farmland,9923917,
+152,furnace,7368816,
+153,oak_sign,9402184,
+154,spruce_sign,9402184,
+155,birch_sign,9402184,
+156,acacia_sign,9402184,
+157,jungle_sign,9402184,
+158,dark_oak_sign,9402184,
+159,oak_door,9402184,
+160,ladder,0,
+161,rail,0,
+162,cobblestone_stairs,7368816,
+163,oak_wall_sign,9402184,
+164,spruce_wall_sign,9402184,
+165,birch_wall_sign,9402184,
+166,acacia_wall_sign,9402184,
+167,jungle_wall_sign,9402184,
+168,dark_oak_wall_sign,9402184,
+169,lever,0,
+170,stone_pressure_plate,7368816,
+171,iron_door,10987431,
+172,oak_pressure_plate,9402184,
+173,spruce_pressure_plate,9402184,
+174,birch_pressure_plate,9402184,
+175,jungle_pressure_plate,9402184,
+176,acacia_pressure_plate,9402184,
+177,dark_oak_pressure_plate,9402184,
+178,redstone_ore,7368816,
+179,redstone_torch,0,
+180,redstone_wall_torch,0,
+181,stone_button,0,
+182,snow,16777215,snow
+183,ice,10526975,snow
+184,snow_block,16777215,snow
+185,cactus,31744,plant
+186,clay,10791096,
+187,sugar_cane,31744,plant
+188,jukebox,9402184,
+189,oak_fence,9402184,
+190,pumpkin,31744,
+191,netherrack,7368816,
+192,soul_sand,16247203,
+193,glowstone,0,
+194,nether_portal,0,
+195,carved_pumpkin,31744,
+196,jack_o_lantern,31744,
+197,cake,0,
+198,repeater,0,
+199,white_stained_glass,0,
+200,orange_stained_glass,0,
+201,magenta_stained_glass,0,
+202,light_blue_stained_glass,0,
+203,yellow_stained_glass,0,
+204,lime_stained_glass,0,
+205,pink_stained_glass,0,
+206,gray_stained_glass,0,
+207,light_gray_stained_glass,0,
+208,cyan_stained_glass,0,
+209,purple_stained_glass,0,
+210,blue_stained_glass,0,
+211,brown_stained_glass,0,
+212,green_stained_glass,0,
+213,red_stained_glass,0,
+214,black_stained_glass,0,
+215,oak_trapdoor,9402184,
+216,spruce_trapdoor,9402184,
+217,birch_trapdoor,9402184,
+218,jungle_trapdoor,9402184,
+219,acacia_trapdoor,9402184,
+220,dark_oak_trapdoor,9402184,
+221,stone_bricks,7368816,
+222,mossy_stone_bricks,7368816,
+223,cracked_stone_bricks,7368816,
+224,chiseled_stone_bricks,7368816,
+225,infested_stone,10791096,
+226,infested_cobblestone,10791096,
+227,infested_stone_bricks,10791096,
+228,infested_mossy_stone_bricks,10791096,
+229,infested_cracked_stone_bricks,10791096,
+230,infested_chiseled_stone_bricks,10791096,
+231,brown_mushroom_block,9402184,tree
+232,red_mushroom_block,9402184,tree
+233,mushroom_stem,9402184,tree
+234,iron_bars,10987431,
+235,glass_pane,0,
+236,melon,31744,
+237,attached_pumpkin_stem,31744,
+238,attached_melon_stem,31744,
+239,pumpkin_stem,31744,
+240,melon_stem,31744,
+241,vine,31744,plant
+242,oak_fence_gate,9402184,
+243,brick_stairs,7368816,
+244,stone_brick_stairs,7368816,
+245,mycelium,8368696,
+246,lily_pad,31744,grass
+247,nether_bricks,7368816,
+248,nether_brick_fence,7368816,
+249,nether_brick_stairs,7368816,
+250,nether_wart,31744,
+251,enchanting_table,7368816,
+252,brewing_stand,10987431,
+253,cauldron,10987431,
+254,end_portal,0,
+255,end_portal_frame,7368816,
+256,end_stone,7368816,
+257,dragon_egg,31744,
+258,redstone_lamp,0,
+259,cocoa,31744,
+260,sandstone_stairs,7368816,
+261,emerald_ore,7368816,
+262,ender_chest,7368816,
+263,tripwire_hook,0,
+264,tripwire,0,
+265,emerald_block,10987431,
+266,spruce_stairs,9402184,
+267,birch_stairs,9402184,
+268,jungle_stairs,9402184,
+269,command_block,10987431,
+270,beacon,0,
+271,cobblestone_wall,7368816,
+272,mossy_cobblestone_wall,7368816,
+273,flower_pot,0,
+274,potted_oak_sapling,0,
+275,potted_spruce_sapling,0,
+276,potted_birch_sapling,0,
+277,potted_jungle_sapling,0,
+278,potted_acacia_sapling,0,
+279,potted_dark_oak_sapling,0,
+280,potted_fern,0,
+281,potted_dandelion,0,
+282,potted_poppy,0,
+283,potted_blue_orchid,0,
+284,potted_allium,0,
+285,potted_azure_bluet,0,
+286,potted_red_tulip,0,
+287,potted_orange_tulip,0,
+288,potted_white_tulip,0,
+289,potted_pink_tulip,0,
+290,potted_oxeye_daisy,0,
+291,potted_cornflower,0,
+292,potted_lily_of_the_valley,0,
+293,potted_wither_rose,0,
+294,potted_red_mushroom,0,
+295,potted_brown_mushroom,0,
+296,potted_dead_bush,0,
+297,potted_cactus,0,
+298,carrots,31744,
+299,potatoes,31744,
+300,oak_button,0,
+301,spruce_button,0,
+302,birch_button,0,
+303,jungle_button,0,
+304,acacia_button,0,
+305,dark_oak_button,0,
+306,skeleton_skull,0,
+307,skeleton_wall_skull,0,
+308,wither_skeleton_skull,0,
+309,wither_skeleton_wall_skull,0,
+310,zombie_head,0,
+311,zombie_wall_head,0,
+312,player_head,0,
+313,player_wall_head,0,
+314,creeper_head,0,
+315,creeper_wall_head,0,
+316,dragon_head,0,
+317,dragon_wall_head,0,
+318,anvil,10987431,
+319,chipped_anvil,10987431,
+320,damaged_anvil,10987431,
+321,trapped_chest,9402184,
+322,light_weighted_pressure_plate,10987431,
+323,heavy_weighted_pressure_plate,10987431,
+324,comparator,0,
+325,daylight_detector,9402184,
+326,redstone_block,10987431,
+327,nether_quartz_ore,7368816,
+328,hopper,10987431,
+329,quartz_block,7368816,
+330,chiseled_quartz_block,7368816,
+331,quartz_pillar,7368816,
+332,quartz_stairs,7368816,
+333,activator_rail,0,
+334,dropper,7368816,
+335,white_terracotta,7368816,
+336,orange_terracotta,7368816,
+337,magenta_terracotta,7368816,
+338,light_blue_terracotta,7368816,
+339,yellow_terracotta,7368816,
+340,lime_terracotta,7368816,
+341,pink_terracotta,7368816,
+342,gray_terracotta,7368816,
+343,light_gray_terracotta,7368816,
+344,cyan_terracotta,7368816,
+345,purple_terracotta,7368816,
+346,blue_terracotta,7368816,
+347,brown_terracotta,7368816,
+348,green_terracotta,7368816,
+349,red_terracotta,7368816,
+350,black_terracotta,7368816,
+351,white_stained_glass_pane,0,
+352,orange_stained_glass_pane,0,
+353,magenta_stained_glass_pane,0,
+354,light_blue_stained_glass_pane,0,
+355,yellow_stained_glass_pane,0,
+356,lime_stained_glass_pane,0,
+357,pink_stained_glass_pane,0,
+358,gray_stained_glass_pane,0,
+359,light_gray_stained_glass_pane,0,
+360,cyan_stained_glass_pane,0,
+361,purple_stained_glass_pane,0,
+362,blue_stained_glass_pane,0,
+363,brown_stained_glass_pane,0,
+364,green_stained_glass_pane,0,
+365,red_stained_glass_pane,0,
+366,black_stained_glass_pane,0,
+367,acacia_stairs,9402184,
+368,dark_oak_stairs,9402184,
+369,slime_block,10791096,
+370,barrier,0,
+371,iron_trapdoor,10987431,
+372,prismarine,7368816,
+373,prismarine_bricks,7368816,
+374,dark_prismarine,7368816,
+375,prismarine_stairs,7368816,
+376,prismarine_brick_stairs,7368816,
+377,dark_prismarine_stairs,7368816,
+378,prismarine_slab,7368816,
+379,prismarine_brick_slab,7368816,
+380,dark_prismarine_slab,7368816,
+381,sea_lantern,0,
+382,hay_block,8368696,
+383,white_carpet,13092807,
+384,orange_carpet,13092807,
+385,magenta_carpet,13092807,
+386,light_blue_carpet,13092807,
+387,yellow_carpet,13092807,
+388,lime_carpet,13092807,
+389,pink_carpet,13092807,
+390,gray_carpet,13092807,
+391,light_gray_carpet,13092807,
+392,cyan_carpet,13092807,
+393,purple_carpet,13092807,
+394,blue_carpet,13092807,
+395,brown_carpet,13092807,
+396,green_carpet,13092807,
+397,red_carpet,13092807,
+398,black_carpet,13092807,
+399,terracotta,7368816,
+400,coal_block,7368816,
+401,packed_ice,10526975,
+402,sunflower,31744,flower
+403,lilac,31744,flower
+404,rose_bush,31744,flower
+405,peony,31744,flower
+406,tall_grass,31744,plant
+407,large_fern,31744,plant
+408,white_banner,9402184,
+409,orange_banner,9402184,
+410,magenta_banner,9402184,
+411,light_blue_banner,9402184,
+412,yellow_banner,9402184,
+413,lime_banner,9402184,
+414,pink_banner,9402184,
+415,gray_banner,9402184,
+416,light_gray_banner,9402184,
+417,cyan_banner,9402184,
+418,purple_banner,9402184,
+419,blue_banner,9402184,
+420,brown_banner,9402184,
+421,green_banner,9402184,
+422,red_banner,9402184,
+423,black_banner,9402184,
+424,white_wall_banner,9402184,
+425,orange_wall_banner,9402184,
+426,magenta_wall_banner,9402184,
+427,light_blue_wall_banner,9402184,
+428,yellow_wall_banner,9402184,
+429,lime_wall_banner,9402184,
+430,pink_wall_banner,9402184,
+431,gray_wall_banner,9402184,
+432,light_gray_wall_banner,9402184,
+433,cyan_wall_banner,9402184,
+434,purple_wall_banner,9402184,
+435,blue_wall_banner,9402184,
+436,brown_wall_banner,9402184,
+437,green_wall_banner,9402184,
+438,red_wall_banner,9402184,
+439,black_wall_banner,9402184,
+440,red_sandstone,7368816,
+441,chiseled_red_sandstone,7368816,
+442,cut_red_sandstone,7368816,
+443,red_sandstone_stairs,7368816,
+444,oak_slab,9402184,
+445,spruce_slab,9402184,
+446,birch_slab,9402184,
+447,jungle_slab,9402184,
+448,acacia_slab,9402184,
+449,dark_oak_slab,9402184,
+450,stone_slab,7368816,
+451,smooth_stone_slab,7368816,
+452,sandstone_slab,7368816,
+453,cut_sandstone_slab,7368816,
+454,petrified_oak_slab,7368816,
+455,cobblestone_slab,7368816,
+456,brick_slab,7368816,
+457,stone_brick_slab,7368816,
+458,nether_brick_slab,7368816,
+459,quartz_slab,7368816,
+460,red_sandstone_slab,7368816,
+461,cut_red_sandstone_slab,7368816,
+462,purpur_slab,7368816,
+463,smooth_stone,7368816,
+464,smooth_sandstone,7368816,
+465,smooth_quartz,7368816,
+466,smooth_red_sandstone,7368816,
+467,spruce_fence_gate,9402184,
+468,birch_fence_gate,9402184,
+469,jungle_fence_gate,9402184,
+470,acacia_fence_gate,9402184,
+471,dark_oak_fence_gate,9402184,
+472,spruce_fence,9402184,
+473,birch_fence,9402184,
+474,jungle_fence,9402184,
+475,acacia_fence,9402184,
+476,dark_oak_fence,9402184,
+477,spruce_door,9402184,
+478,birch_door,9402184,
+479,jungle_door,9402184,
+480,acacia_door,9402184,
+481,dark_oak_door,9402184,
+482,end_rod,0,
+483,chorus_plant,31744,
+484,chorus_flower,31744,
+485,purpur_block,7368816,
+486,purpur_pillar,7368816,
+487,purpur_stairs,7368816,
+488,end_stone_bricks,7368816,
+489,beetroots,31744,
+490,grass_path,9923917,
+491,end_gateway,0,
+492,repeating_command_block,10987431,
+493,chain_command_block,10987431,
+494,frosted_ice,10526975,
+495,magma_block,7368816,
+496,nether_wart_block,8368696,
+497,red_nether_bricks,7368816,
+498,bone_block,7368816,
+499,structure_void,0,
+500,observer,7368816,
+501,shulker_box,8339378,
+502,white_shulker_box,8339378,
+503,orange_shulker_box,8339378,
+504,magenta_shulker_box,8339378,
+505,light_blue_shulker_box,8339378,
+506,yellow_shulker_box,8339378,
+507,lime_shulker_box,8339378,
+508,pink_shulker_box,8339378,
+509,gray_shulker_box,8339378,
+510,light_gray_shulker_box,8339378,
+511,cyan_shulker_box,8339378,
+512,purple_shulker_box,8339378,
+513,blue_shulker_box,8339378,
+514,brown_shulker_box,8339378,
+515,green_shulker_box,8339378,
+516,red_shulker_box,8339378,
+517,black_shulker_box,8339378,
+518,white_glazed_terracotta,7368816,
+519,orange_glazed_terracotta,7368816,
+520,magenta_glazed_terracotta,7368816,
+521,light_blue_glazed_terracotta,7368816,
+522,yellow_glazed_terracotta,7368816,
+523,lime_glazed_terracotta,7368816,
+524,pink_glazed_terracotta,7368816,
+525,gray_glazed_terracotta,7368816,
+526,light_gray_glazed_terracotta,7368816,
+527,cyan_glazed_terracotta,7368816,
+528,purple_glazed_terracotta,7368816,
+529,blue_glazed_terracotta,7368816,
+530,brown_glazed_terracotta,7368816,
+531,green_glazed_terracotta,7368816,
+532,red_glazed_terracotta,7368816,
+533,black_glazed_terracotta,7368816,
+534,white_concrete,7368816,
+535,orange_concrete,7368816,
+536,magenta_concrete,7368816,
+537,light_blue_concrete,7368816,
+538,yellow_concrete,7368816,
+539,lime_concrete,7368816,
+540,pink_concrete,7368816,
+541,gray_concrete,7368816,
+542,light_gray_concrete,7368816,
+543,cyan_concrete,7368816,
+544,purple_concrete,7368816,
+545,blue_concrete,7368816,
+546,brown_concrete,7368816,
+547,green_concrete,7368816,
+548,red_concrete,7368816,
+549,black_concrete,7368816,
+550,white_concrete_powder,16247203,
+551,orange_concrete_powder,16247203,
+552,magenta_concrete_powder,16247203,
+553,light_blue_concrete_powder,16247203,
+554,yellow_concrete_powder,16247203,
+555,lime_concrete_powder,16247203,
+556,pink_concrete_powder,16247203,
+557,gray_concrete_powder,16247203,
+558,light_gray_concrete_powder,16247203,
+559,cyan_concrete_powder,16247203,
+560,purple_concrete_powder,16247203,
+561,blue_concrete_powder,16247203,
+562,brown_concrete_powder,16247203,
+563,green_concrete_powder,16247203,
+564,red_concrete_powder,16247203,
+565,black_concrete_powder,16247203,
+566,kelp,4210943,
+567,kelp_plant,4210943,
+568,dried_kelp_block,8368696,
+569,turtle_egg,31744,
+570,dead_tube_coral_block,7368816,
+571,dead_brain_coral_block,7368816,
+572,dead_bubble_coral_block,7368816,
+573,dead_fire_coral_block,7368816,
+574,dead_horn_coral_block,7368816,
+575,tube_coral_block,7368816,
+576,brain_coral_block,7368816,
+577,bubble_coral_block,7368816,
+578,fire_coral_block,7368816,
+579,horn_coral_block,7368816,
+580,dead_tube_coral,7368816,
+581,dead_brain_coral,7368816,
+582,dead_bubble_coral,7368816,
+583,dead_fire_coral,7368816,
+584,dead_horn_coral,7368816,
+585,tube_coral,4210943,
+586,brain_coral,4210943,
+587,bubble_coral,4210943,
+588,fire_coral,4210943,
+589,horn_coral,4210943,
+590,dead_tube_coral_fan,7368816,
+591,dead_brain_coral_fan,7368816,
+592,dead_bubble_coral_fan,7368816,
+593,dead_fire_coral_fan,7368816,
+594,dead_horn_coral_fan,7368816,
+595,tube_coral_fan,4210943,
+596,brain_coral_fan,4210943,
+597,bubble_coral_fan,4210943,
+598,fire_coral_fan,4210943,
+599,horn_coral_fan,4210943,
+600,dead_tube_coral_wall_fan,7368816,
+601,dead_brain_coral_wall_fan,7368816,
+602,dead_bubble_coral_wall_fan,7368816,
+603,dead_fire_coral_wall_fan,7368816,
+604,dead_horn_coral_wall_fan,7368816,
+605,tube_coral_wall_fan,4210943,
+606,brain_coral_wall_fan,4210943,
+607,bubble_coral_wall_fan,4210943,
+608,fire_coral_wall_fan,4210943,
+609,horn_coral_wall_fan,4210943,
+610,sea_pickle,4210943,
+611,blue_ice,10526975,
+612,conduit,0,
+613,bamboo_sapling,9402184,plant
+614,bamboo,9402184,plant
+615,potted_bamboo,0,
+616,void_air,0,dirt
+617,cave_air,0,dirt
+618,bubble_column,4210943,
+619,polished_granite_stairs,7368816,
+620,smooth_red_sandstone_stairs,7368816,
+621,mossy_stone_brick_stairs,7368816,
+622,polished_diorite_stairs,7368816,
+623,mossy_cobblestone_stairs,7368816,
+624,end_stone_brick_stairs,7368816,
+625,stone_stairs,7368816,
+626,smooth_sandstone_stairs,7368816,
+627,smooth_quartz_stairs,7368816,
+628,granite_stairs,7368816,
+629,andesite_stairs,7368816,
+630,red_nether_brick_stairs,7368816,
+631,polished_andesite_stairs,7368816,
+632,diorite_stairs,7368816,
+633,polished_granite_slab,7368816,
+634,smooth_red_sandstone_slab,7368816,
+635,mossy_stone_brick_slab,7368816,
+636,polished_diorite_slab,7368816,
+637,mossy_cobblestone_slab,7368816,
+638,end_stone_brick_slab,7368816,
+639,smooth_sandstone_slab,7368816,
+640,smooth_quartz_slab,7368816,
+641,granite_slab,7368816,
+642,andesite_slab,7368816,
+643,red_nether_brick_slab,7368816,
+644,polished_andesite_slab,7368816,
+645,diorite_slab,7368816,
+646,brick_wall,7368816,
+647,prismarine_wall,7368816,
+648,red_sandstone_wall,7368816,
+649,mossy_stone_brick_wall,7368816,
+650,granite_wall,7368816,
+651,stone_brick_wall,7368816,
+652,nether_brick_wall,7368816,
+653,andesite_wall,7368816,
+654,red_nether_brick_wall,7368816,
+655,sandstone_wall,7368816,
+656,end_stone_brick_wall,7368816,
+657,diorite_wall,7368816,
+658,scaffolding,0,
+659,loom,9402184,
+660,barrel,9402184,
+661,smoker,7368816,
+662,blast_furnace,7368816,
+663,cartography_table,9402184,
+664,fletching_table,9402184,
+665,grindstone,10987431,
+666,lectern,9402184,
+667,smithing_table,9402184,
+668,stonecutter,7368816,
+669,bell,10987431,
+670,lantern,10987431,
+671,campfire,9402184,
+672,sweet_berry_bush,31744,
+673,structure_block,10987431,
+674,jigsaw,10987431,
+675,composter,9402184,
+676,bee_nest,9402184,
+677,beehive,9402184,
+678,honey_block,10791096,
+679,honeycomb_block,10791096,
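
`id2name_gg.csv` (like `mc_reduction.csv` further below) lists, per Minecraft block id: the block name, its map color packed as a decimal `0xRRGGBB` integer, and an optional reduced semantic label (empty when unmapped). A tiny, purely illustrative sketch for unpacking the color column:

```python
# Sketch: decode the packed decimal color column into an (R, G, B) triple.
def unpack_color(value: int):
    return (value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF

# e.g. unpack_color(8368696) -> (127, 178, 56), the grass map color
```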
diff --git a/imaginaire/model_utils/gancraft/layers.py b/imaginaire/model_utils/gancraft/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d49900b0187c575d37c59a4fa6f62fb06413ef1d
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/layers.py
@@ -0,0 +1,153 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class AffineMod(nn.Module):
+    r"""Learning affine modulation of activation.
+
+    Args:
+        in_features (int): Number of input features.
+        style_features (int): Number of style features.
+        mod_bias (bool): Whether to modulate bias.
+    """
+
+    def __init__(self,
+                 in_features,
+                 style_features,
+                 mod_bias=True
+                 ):
+        super().__init__()
+        self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+        self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float))  # init to 1
+        self.weight_beta = None
+        self.bias_beta = None
+        self.mod_bias = mod_bias
+        if mod_bias:
+            self.weight_beta = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+            self.bias_beta = nn.Parameter(torch.full([in_features], 0, dtype=torch.float))
+
+    @staticmethod
+    def _linear_f(x, w, b):
+        w = w.to(x.dtype)
+        x_shape = x.shape
+        x = x.reshape(-1, x_shape[-1])
+        if b is not None:
+            b = b.to(x.dtype)
+            x = torch.addmm(b.unsqueeze(0), x, w.t())
+        else:
+            x = x.matmul(w.t())
+        x = x.reshape(*x_shape[:-1], -1)
+        return x
+
+    # x: [B, ..., Cin]
+    # z: [B, 1, 1, Cz] (only the batch and last dimensions are used)
+    def forward(self, x, z):
+        x_shape = x.shape
+        z_shape = z.shape
+        x = x.reshape(x_shape[0], -1, x_shape[-1])
+        z = z.reshape(z_shape[0], 1, z_shape[-1])
+
+        alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha)  # [B, ..., I]
+        x = x * alpha
+
+        if self.mod_bias:
+            beta = self._linear_f(z, self.weight_beta, self.bias_beta)  # [B, ..., I]
+            x = x + beta
+
+        x = x.reshape(*x_shape[:-1], x.shape[-1])
+        return x
+
+
+class ModLinear(nn.Module):
+    r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
+    Equivalent to an affine modulation of the input followed by a linear layer (with the additive part applied to the
+    output instead when output_mode=True), but faster when the same modulation parameters are shared across multiple inputs.
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        style_features (int): Number of style features.
+        bias (bool): Apply additive bias before the activation function?
+        mod_bias (bool): Whether to modulate bias.
+        output_mode (bool): If True, apply the additive modulation (beta) to the output instead of the input.
+        weight_gain (float): Initialization gain.
+        bias_init (float): Initial value of the additive bias.
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 style_features,
+                 bias=True,
+                 mod_bias=True,
+                 output_mode=False,
+                 weight_gain=1,
+                 bias_init=0
+                 ):
+        super().__init__()
+        weight_gain = weight_gain / np.sqrt(in_features)
+        self.weight = nn.Parameter(torch.randn([out_features, in_features]) * weight_gain)
+        self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+        self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+        self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float))  # init to 1
+        self.weight_beta = None
+        self.bias_beta = None
+        self.mod_bias = mod_bias
+        self.output_mode = output_mode
+        if mod_bias:
+            if output_mode:
+                mod_bias_dims = out_features
+            else:
+                mod_bias_dims = in_features
+            self.weight_beta = nn.Parameter(torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features))
+            self.bias_beta = nn.Parameter(torch.full([mod_bias_dims], 0, dtype=torch.float))
+
+    @staticmethod
+    def _linear_f(x, w, b):
+        w = w.to(x.dtype)
+        x_shape = x.shape
+        x = x.reshape(-1, x_shape[-1])
+        if b is not None:
+            b = b.to(x.dtype)
+            x = torch.addmm(b.unsqueeze(0), x, w.t())
+        else:
+            x = x.matmul(w.t())
+        x = x.reshape(*x_shape[:-1], -1)
+        return x
+
+    # x: [B, ..., Cin]
+    # z: [B, 1, 1, Cz] (only the batch and last dimensions are used)
+    def forward(self, x, z):
+        x_shape = x.shape
+        z_shape = z.shape
+        x = x.reshape(x_shape[0], -1, x_shape[-1])
+        z = z.reshape(z_shape[0], 1, z_shape[-1])
+
+        alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha)  # [B, ..., I]
+        w = self.weight.to(x.dtype)  # [O I]
+        w = w.unsqueeze(0) * alpha  # [1 O I] * [B 1 I] = [B O I]
+
+        if self.mod_bias:
+            beta = self._linear_f(z, self.weight_beta, self.bias_beta)  # [B, ..., I]
+            if not self.output_mode:
+                x = x + beta
+
+        b = self.bias
+        if b is not None:
+            b = b.to(x.dtype)[None, None, :]
+        if self.mod_bias and self.output_mode:
+            if b is None:
+                b = beta
+            else:
+                b = b + beta
+
+        # [B ? I] @ [B I O] = [B ? O]
+        if b is not None:
+            x = torch.baddbmm(b, x, w.transpose(1, 2))
+        else:
+            x = x.bmm(w.transpose(1, 2))
+        x = x.reshape(*x_shape[:-1], x.shape[-1])
+        return x
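
`ModLinear` folds the per-sample style modulation into the weight tensor so a single batched matmul serves every sample that shares a style code. A minimal shape-check sketch; the tensor sizes below are assumptions chosen for illustration, not values from this repository:

```python
# Sketch: each batch element carries its own style vector z, broadcast over
# the flattened spatial/sample dimension of x.
import torch
from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear

B, N, C_in, C_out, C_style = 2, 4096, 64, 128, 256
x = torch.randn(B, N, C_in)    # e.g. per-ray feature vectors
z = torch.randn(B, C_style)    # one style code per batch element

mod = AffineMod(C_in, C_style)
lin = ModLinear(C_in, C_out, C_style, mod_bias=True)

print(mod(x, z).shape)  # torch.Size([2, 4096, 64])
print(lin(x, z).shape)  # torch.Size([2, 4096, 128])
```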
diff --git a/imaginaire/model_utils/gancraft/loss.py b/imaginaire/model_utils/gancraft/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1811de5307535167f645b4ea8a889a468b41780
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/loss.py
@@ -0,0 +1,96 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GANLoss(nn.Module):
+    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
+        r"""GAN loss constructor.
+
+        Args:
+            target_real_label (float): Desired output label for the real images.
+            target_fake_label (float): Desired output label for the fake images.
+        """
+        super(GANLoss, self).__init__()
+        self.real_label = target_real_label
+        self.fake_label = target_fake_label
+        self.real_label_tensor = None
+        self.fake_label_tensor = None
+
+    def forward(self, input_x, t_real, weight=None,
+                reduce_dim=True, dis_update=True):
+        r"""GAN loss computation.
+
+        Args:
+            input_x (tensor or list of tensors): Output values.
+            t_real (boolean): Whether this output value is for real images.
+            weight (float): Weight to scale the loss value.
+            reduce_dim (boolean): Whether to reduce the dimensions first. This makes a difference when using
+                multi-resolution discriminators.
+            dis_update (boolean): Whether the discriminator (rather than the generator) is being updated.
+        Returns:
+            loss (tensor): Loss value.
+        """
+        if isinstance(input_x, list):
+            loss = 0
+            for pred_i in input_x:
+                if isinstance(pred_i, list):
+                    pred_i = pred_i[-1]
+                loss_tensor = self.loss(pred_i, t_real, weight,
+                                        reduce_dim, dis_update)
+                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+                loss += new_loss
+            return loss / len(input_x)
+        else:
+            return self.loss(input_x, t_real, weight, reduce_dim, dis_update)
+
+    def loss(self, input_x, t_real, weight=None,
+             reduce_dim=True, dis_update=True):
+        r"""N+1 label GAN loss computation.
+
+        Args:
+            input_x (dict): Discriminator output with an (N+1)-channel 'pred' map and an N-channel one-hot 'label' map.
+            t_real (boolean): Whether this output value is for real images.
+            weight (float): Weight to scale the loss value.
+            reduce_dim (boolean): Whether to reduce the dimensions first. This makes a difference when using
+                multi-resolution discriminators.
+            dis_update (boolean): Whether the discriminator (rather than the generator) is being updated.
+        Returns:
+            loss (tensor): Loss value.
+        """
+        assert reduce_dim is True
+        pred = input_x['pred'].clone()
+        label = input_x['label'].clone()
+        batch_size = pred.size(0)
+
+        # ignore label 0
+        label[:, 0, ...] = 0
+        pred[:, 0, ...] = 0
+        pred = F.log_softmax(pred, dim=1)
+        assert pred.size(1) == (label.size(1) + 1)
+        if dis_update:
+            if t_real:
+                pred_real = pred[:, :-1, :, :]
+                loss = - label * pred_real
+                loss = torch.sum(loss, dim=1, keepdim=True)
+            else:
+                pred_fake = pred[:, -1, None, :, :]  # N plus 1
+                loss = - pred_fake
+        else:
+            assert t_real, "GAN loss must be aiming for real."
+            pred_real = pred[:, :-1, :, :]
+            loss = - label * pred_real
+            loss = torch.sum(loss, dim=1, keepdim=True)
+
+        if weight is not None:
+            loss = loss * weight
+        if reduce_dim:
+            loss = torch.mean(loss)
+        else:
+            loss = loss.view(batch_size, -1).mean(dim=1)
+        return loss
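
This is the N+1 label formulation: the discriminator prediction carries one channel per semantic class plus a final "fake" channel, and the label is a one-hot semantic map. A hedged usage sketch with assumed shapes (not taken from the repository's training code):

```python
# Sketch: call GANLoss on a dict with an (N+1)-channel 'pred' and an
# N-channel one-hot 'label' map.
import torch
from imaginaire.model_utils.gancraft.loss import GANLoss

criterion = GANLoss()
N, H, W = 12, 32, 32
disc_out = {
    'pred': torch.randn(2, N + 1, H, W),
    'label': torch.zeros(2, N, H, W).scatter_(1, torch.randint(0, N, (2, 1, H, W)), 1.0),
}
d_loss_fake = criterion(disc_out, t_real=False, dis_update=True)
g_loss = criterion(disc_out, t_real=True, dis_update=False)
```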
diff --git a/imaginaire/model_utils/gancraft/mc_lbl_reduction.py b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py
new file mode 100644
index 0000000000000000000000000000000000000000..03fec1d3b600cfd31358cf480924da5232e0104a
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import csv
+
+
+class ReducedLabelMapper:
+    def __init__(self):
+        this_path = os.path.dirname(os.path.abspath(__file__))
+        print('[ReducedLabelMapper] Loading from {}'.format(this_path))
+
+        # Load Minecraft LUT
+        mcid2rdlbl_lut = {}
+        mcid2mclbl_lut = {}
+        with open(os.path.join(this_path, 'mc_reduction.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            for row in csvreader:
+                mcid = int(row[0])
+                mcid2rdlbl_lut[mcid] = row[3]
+                mcid2mclbl_lut[mcid] = row[1]
+
+        # Load reduced label set
+        reduced_lbls = []
+        rdlbl2rdid = {}
+        with open(os.path.join(this_path, 'reduced_coco_lbls.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            for idx, row in enumerate(csvreader):
+                rdlbl2rdid[row[0]] = idx
+                reduced_lbls.append(row[0])
+        print(['{}: {}'.format(rdid, rdlbl) for rdid, rdlbl in enumerate(reduced_lbls)])
+        # The first label should always be 'ignore'
+        assert reduced_lbls[0] == 'ignore'
+
+        # Generate Minecraft ID to Reduced ID LUT
+        mcid2rdid_lut = []
+        for mcid in range(len(mcid2rdlbl_lut)):
+            rdlbl = mcid2rdlbl_lut[mcid]
+            if rdlbl == '':
+                rdlbl = 'ignore'
+            rdid = rdlbl2rdid[rdlbl]
+            mcid2rdid_lut.append(rdid)
+
+        # ================= coco part ==================
+        gg_label_list = []
+        gglbl2ggid = {}
+        with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            for idx, row in enumerate(csvreader):
+                gg_label_list.append(row[0])
+                gglbl2ggid[row[0]] = idx
+
+        # Load coco -> reduced mapping table
+        gglbl2rdid = {}
+        with open(os.path.join(this_path, 'gaugan_reduction.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            for idx, row in enumerate(csvreader):
+                gglbl = row[0]
+                target_rdlbl = row[1]
+                ggid = gglbl2ggid[gglbl]
+                target_rdid = rdlbl2rdid[target_rdlbl]
+                gglbl2rdid[ggid] = target_rdid
+        ggid2rdid = [gglbl2rdid[i] for i in range(len(gglbl2rdid))]
+
+        print('[ReducedLabelMapper] #Reduced Labels: {}'.format(len(reduced_lbls)))
+
+        self.mcid2rdid_lut = mcid2rdid_lut
+        self.ggid2rdid = ggid2rdid
+        self.reduced_lbls = reduced_lbls
+
+        self.ignore_id = rdlbl2rdid['ignore']
+        self.dirt_id = rdlbl2rdid['dirt']
+        self.water_id = rdlbl2rdid['water']
+
+        self.gglbl2ggid = gglbl2ggid
+
+    def get_ggid_from_gglbl(self, gglbl):
+        r"""Look up the GauGAN label id for a label name via the gglbl2ggid dict."""
+        return self.gglbl2ggid[gglbl]
+
+
+if __name__ == '__main__':
+    mapper = ReducedLabelMapper()
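
`ReducedLabelMapper` ties the CSV tables above together into plain Python lookup structures (`mcid2rdid_lut`, `ggid2rdid`, `gglbl2ggid`). A hedged sketch of how they might be used, assuming the CSVs it reads (including `reduced_coco_lbls.csv`, which is not part of this excerpt) sit next to the module:

```python
# Sketch: translate a map of GauGAN label ids into reduced label ids.
import torch
from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper

mapper = ReducedLabelMapper()

gg_ids = torch.randint(0, len(mapper.ggid2rdid), (4, 4))   # toy id map
rd_ids = torch.tensor(mapper.ggid2rdid)[gg_ids]            # reduced id map

# Named ids for frequently used classes.
print(mapper.ignore_id, mapper.dirt_id, mapper.water_id)
```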
diff --git a/imaginaire/model_utils/gancraft/mc_reduction.csv b/imaginaire/model_utils/gancraft/mc_reduction.csv
new file mode 100644
index 0000000000000000000000000000000000000000..254af7255d67be76b5d41cdbd162173e55ec0b9c
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_reduction.csv
@@ -0,0 +1,680 @@
+0,air,0,sky
+1,stone,7368816,stone
+2,granite,7368816,rock
+3,polished_granite,7368816,rock
+4,diorite,7368816,rock
+5,polished_diorite,7368816,rock
+6,andesite,7368816,rock
+7,polished_andesite,7368816,rock
+8,grass_block,8368696,grass
+9,dirt,9923917,dirt
+10,coarse_dirt,9923917,dirt
+11,podzol,9923917,dirt
+12,cobblestone,7368816,stone
+13,oak_planks,9402184,
+14,spruce_planks,9402184,
+15,birch_planks,9402184,
+16,jungle_planks,9402184,
+17,acacia_planks,9402184,
+18,dark_oak_planks,9402184,
+19,oak_sapling,31744,grass
+20,spruce_sapling,31744,grass
+21,birch_sapling,31744,grass
+22,jungle_sapling,31744,grass
+23,acacia_sapling,31744,grass
+24,dark_oak_sapling,31744,grass
+25,bedrock,7368816,rock
+26,water,4210943,water
+27,lava,16711680,
+28,sand,16247203,sand
+29,red_sand,16247203,sand
+30,gravel,16247203,gravel
+31,gold_ore,7368816,rock
+32,iron_ore,7368816,rock
+33,coal_ore,7368816,rock
+34,oak_log,9402184,tree
+35,spruce_log,9402184,tree
+36,birch_log,9402184,tree
+37,jungle_log,9402184,tree
+38,acacia_log,9402184,tree
+39,dark_oak_log,9402184,tree
+40,stripped_spruce_log,9402184,
+41,stripped_birch_log,9402184,
+42,stripped_jungle_log,9402184,
+43,stripped_acacia_log,9402184,
+44,stripped_dark_oak_log,9402184,
+45,stripped_oak_log,9402184,
+46,oak_wood,9402184,
+47,spruce_wood,9402184,
+48,birch_wood,9402184,
+49,jungle_wood,9402184,
+50,acacia_wood,9402184,
+51,dark_oak_wood,9402184,
+52,stripped_oak_wood,9402184,
+53,stripped_spruce_wood,9402184,
+54,stripped_birch_wood,9402184,
+55,stripped_jungle_wood,9402184,
+56,stripped_acacia_wood,9402184,
+57,stripped_dark_oak_wood,9402184,
+58,oak_leaves,31744,tree
+59,spruce_leaves,31744,tree
+60,birch_leaves,31744,tree
+61,jungle_leaves,31744,tree
+62,acacia_leaves,31744,tree
+63,dark_oak_leaves,31744,tree
+64,sponge,15066419,
+65,wet_sponge,15066419,
+66,glass,0,
+67,lapis_ore,7368816,
+68,lapis_block,10987431,
+69,dispenser,7368816,
+70,sandstone,7368816,sand
+71,chiseled_sandstone,7368816,sand
+72,cut_sandstone,7368816,sand
+73,note_block,9402184,
+74,white_bed,13092807,
+75,orange_bed,13092807,
+76,magenta_bed,13092807,
+77,light_blue_bed,13092807,
+78,yellow_bed,13092807,
+79,lime_bed,13092807,
+80,pink_bed,13092807,
+81,gray_bed,13092807,
+82,light_gray_bed,13092807,
+83,cyan_bed,13092807,
+84,purple_bed,13092807,
+85,blue_bed,13092807,
+86,brown_bed,13092807,
+87,green_bed,13092807,
+88,red_bed,13092807,
+89,black_bed,13092807,
+90,powered_rail,0,
+91,detector_rail,0,
+92,sticky_piston,7368816,
+93,cobweb,13092807,
+94,grass,31744,grass
+95,fern,31744,grass
+96,dead_bush,31744,grass
+97,seagrass,4210943,water
+98,tall_seagrass,4210943,water
+99,piston,7368816,
+100,piston_head,7368816,
+101,white_wool,13092807,
+102,orange_wool,13092807,
+103,magenta_wool,13092807,
+104,light_blue_wool,13092807,
+105,yellow_wool,13092807,
+106,lime_wool,13092807,
+107,pink_wool,13092807,
+108,gray_wool,13092807,
+109,light_gray_wool,13092807,
+110,cyan_wool,13092807,
+111,purple_wool,13092807,
+112,blue_wool,13092807,
+113,brown_wool,13092807,
+114,green_wool,13092807,
+115,red_wool,13092807,
+116,black_wool,13092807,
+117,moving_piston,7368816,
+118,dandelion,31744,flower
+119,poppy,31744,flower
+120,blue_orchid,31744,flower
+121,allium,31744,flower
+122,azure_bluet,31744,flower
+123,red_tulip,31744,flower
+124,orange_tulip,31744,flower
+125,white_tulip,31744,flower
+126,pink_tulip,31744,flower
+127,oxeye_daisy,31744,flower
+128,cornflower,31744,flower
+129,wither_rose,31744,flower
+130,lily_of_the_valley,31744,flower
+131,brown_mushroom,31744,flower
+132,red_mushroom,31744,flower
+133,gold_block,10987431,
+134,iron_block,10987431,
+135,bricks,7368816,
+136,tnt,16711680,
+137,bookshelf,9402184,
+138,mossy_cobblestone,7368816,
+139,obsidian,7368816,
+140,torch,0,
+141,wall_torch,0,
+142,fire,0,
+143,spawner,7368816,
+144,oak_stairs,9402184,
+145,chest,9402184,
+146,redstone_wire,0,
+147,diamond_ore,7368816,
+148,diamond_block,10987431,
+149,crafting_table,9402184,
+150,wheat,31744,
+151,farmland,9923917,
+152,furnace,7368816,
+153,oak_sign,9402184,
+154,spruce_sign,9402184,
+155,birch_sign,9402184,
+156,acacia_sign,9402184,
+157,jungle_sign,9402184,
+158,dark_oak_sign,9402184,
+159,oak_door,9402184,
+160,ladder,0,
+161,rail,0,
+162,cobblestone_stairs,7368816,
+163,oak_wall_sign,9402184,
+164,spruce_wall_sign,9402184,
+165,birch_wall_sign,9402184,
+166,acacia_wall_sign,9402184,
+167,jungle_wall_sign,9402184,
+168,dark_oak_wall_sign,9402184,
+169,lever,0,
+170,stone_pressure_plate,7368816,
+171,iron_door,10987431,
+172,oak_pressure_plate,9402184,
+173,spruce_pressure_plate,9402184,
+174,birch_pressure_plate,9402184,
+175,jungle_pressure_plate,9402184,
+176,acacia_pressure_plate,9402184,
+177,dark_oak_pressure_plate,9402184,
+178,redstone_ore,7368816,
+179,redstone_torch,0,
+180,redstone_wall_torch,0,
+181,stone_button,0,
+182,snow,16777215,snow
+183,ice,10526975,snow
+184,snow_block,16777215,snow
+185,cactus,31744,flower
+186,clay,10791096,dirt
+187,sugar_cane,31744,flower
+188,jukebox,9402184,
+189,oak_fence,9402184,
+190,pumpkin,31744,
+191,netherrack,7368816,
+192,soul_sand,16247203,
+193,glowstone,0,
+194,nether_portal,0,
+195,carved_pumpkin,31744,
+196,jack_o_lantern,31744,
+197,cake,0,
+198,repeater,0,
+199,white_stained_glass,0,
+200,orange_stained_glass,0,
+201,magenta_stained_glass,0,
+202,light_blue_stained_glass,0,
+203,yellow_stained_glass,0,
+204,lime_stained_glass,0,
+205,pink_stained_glass,0,
+206,gray_stained_glass,0,
+207,light_gray_stained_glass,0,
+208,cyan_stained_glass,0,
+209,purple_stained_glass,0,
+210,blue_stained_glass,0,
+211,brown_stained_glass,0,
+212,green_stained_glass,0,
+213,red_stained_glass,0,
+214,black_stained_glass,0,
+215,oak_trapdoor,9402184,
+216,spruce_trapdoor,9402184,
+217,birch_trapdoor,9402184,
+218,jungle_trapdoor,9402184,
+219,acacia_trapdoor,9402184,
+220,dark_oak_trapdoor,9402184,
+221,stone_bricks,7368816,
+222,mossy_stone_bricks,7368816,
+223,cracked_stone_bricks,7368816,
+224,chiseled_stone_bricks,7368816,
+225,infested_stone,10791096,
+226,infested_cobblestone,10791096,
+227,infested_stone_bricks,10791096,
+228,infested_mossy_stone_bricks,10791096,
+229,infested_cracked_stone_bricks,10791096,
+230,infested_chiseled_stone_bricks,10791096,
+231,brown_mushroom_block,9402184,tree
+232,red_mushroom_block,9402184,tree
+233,mushroom_stem,9402184,tree
+234,iron_bars,10987431,
+235,glass_pane,0,
+236,melon,31744,
+237,attached_pumpkin_stem,31744,
+238,attached_melon_stem,31744,
+239,pumpkin_stem,31744,
+240,melon_stem,31744,
+241,vine,31744,tree
+242,oak_fence_gate,9402184,
+243,brick_stairs,7368816,
+244,stone_brick_stairs,7368816,
+245,mycelium,8368696,
+246,lily_pad,31744,grass
+247,nether_bricks,7368816,
+248,nether_brick_fence,7368816,
+249,nether_brick_stairs,7368816,
+250,nether_wart,31744,
+251,enchanting_table,7368816,
+252,brewing_stand,10987431,
+253,cauldron,10987431,
+254,end_portal,0,
+255,end_portal_frame,7368816,
+256,end_stone,7368816,
+257,dragon_egg,31744,
+258,redstone_lamp,0,
+259,cocoa,31744,
+260,sandstone_stairs,7368816,
+261,emerald_ore,7368816,
+262,ender_chest,7368816,
+263,tripwire_hook,0,
+264,tripwire,0,
+265,emerald_block,10987431,
+266,spruce_stairs,9402184,
+267,birch_stairs,9402184,
+268,jungle_stairs,9402184,
+269,command_block,10987431,
+270,beacon,0,
+271,cobblestone_wall,7368816,
+272,mossy_cobblestone_wall,7368816,
+273,flower_pot,0,
+274,potted_oak_sapling,0,
+275,potted_spruce_sapling,0,
+276,potted_birch_sapling,0,
+277,potted_jungle_sapling,0,
+278,potted_acacia_sapling,0,
+279,potted_dark_oak_sapling,0,
+280,potted_fern,0,
+281,potted_dandelion,0,
+282,potted_poppy,0,
+283,potted_blue_orchid,0,
+284,potted_allium,0,
+285,potted_azure_bluet,0,
+286,potted_red_tulip,0,
+287,potted_orange_tulip,0,
+288,potted_white_tulip,0,
+289,potted_pink_tulip,0,
+290,potted_oxeye_daisy,0,
+291,potted_cornflower,0,
+292,potted_lily_of_the_valley,0,
+293,potted_wither_rose,0,
+294,potted_red_mushroom,0,
+295,potted_brown_mushroom,0,
+296,potted_dead_bush,0,
+297,potted_cactus,0,
+298,carrots,31744,
+299,potatoes,31744,
+300,oak_button,0,
+301,spruce_button,0,
+302,birch_button,0,
+303,jungle_button,0,
+304,acacia_button,0,
+305,dark_oak_button,0,
+306,skeleton_skull,0,
+307,skeleton_wall_skull,0,
+308,wither_skeleton_skull,0,
+309,wither_skeleton_wall_skull,0,
+310,zombie_head,0,
+311,zombie_wall_head,0,
+312,player_head,0,
+313,player_wall_head,0,
+314,creeper_head,0,
+315,creeper_wall_head,0,
+316,dragon_head,0,
+317,dragon_wall_head,0,
+318,anvil,10987431,
+319,chipped_anvil,10987431,
+320,damaged_anvil,10987431,
+321,trapped_chest,9402184,
+322,light_weighted_pressure_plate,10987431,
+323,heavy_weighted_pressure_plate,10987431,
+324,comparator,0,
+325,daylight_detector,9402184,
+326,redstone_block,10987431,
+327,nether_quartz_ore,7368816,
+328,hopper,10987431,
+329,quartz_block,7368816,
+330,chiseled_quartz_block,7368816,
+331,quartz_pillar,7368816,
+332,quartz_stairs,7368816,
+333,activator_rail,0,
+334,dropper,7368816,
+335,white_terracotta,7368816,
+336,orange_terracotta,7368816,
+337,magenta_terracotta,7368816,
+338,light_blue_terracotta,7368816,
+339,yellow_terracotta,7368816,
+340,lime_terracotta,7368816,
+341,pink_terracotta,7368816,
+342,gray_terracotta,7368816,
+343,light_gray_terracotta,7368816,
+344,cyan_terracotta,7368816,
+345,purple_terracotta,7368816,
+346,blue_terracotta,7368816,
+347,brown_terracotta,7368816,
+348,green_terracotta,7368816,
+349,red_terracotta,7368816,
+350,black_terracotta,7368816,
+351,white_stained_glass_pane,0,
+352,orange_stained_glass_pane,0,
+353,magenta_stained_glass_pane,0,
+354,light_blue_stained_glass_pane,0,
+355,yellow_stained_glass_pane,0,
+356,lime_stained_glass_pane,0,
+357,pink_stained_glass_pane,0,
+358,gray_stained_glass_pane,0,
+359,light_gray_stained_glass_pane,0,
+360,cyan_stained_glass_pane,0,
+361,purple_stained_glass_pane,0,
+362,blue_stained_glass_pane,0,
+363,brown_stained_glass_pane,0,
+364,green_stained_glass_pane,0,
+365,red_stained_glass_pane,0,
+366,black_stained_glass_pane,0,
+367,acacia_stairs,9402184,
+368,dark_oak_stairs,9402184,
+369,slime_block,10791096,
+370,barrier,0,
+371,iron_trapdoor,10987431,
+372,prismarine,7368816,
+373,prismarine_bricks,7368816,
+374,dark_prismarine,7368816,
+375,prismarine_stairs,7368816,
+376,prismarine_brick_stairs,7368816,
+377,dark_prismarine_stairs,7368816,
+378,prismarine_slab,7368816,
+379,prismarine_brick_slab,7368816,
+380,dark_prismarine_slab,7368816,
+381,sea_lantern,0,
+382,hay_block,8368696,
+383,white_carpet,13092807,
+384,orange_carpet,13092807,
+385,magenta_carpet,13092807,
+386,light_blue_carpet,13092807,
+387,yellow_carpet,13092807,
+388,lime_carpet,13092807,
+389,pink_carpet,13092807,
+390,gray_carpet,13092807,
+391,light_gray_carpet,13092807,
+392,cyan_carpet,13092807,
+393,purple_carpet,13092807,
+394,blue_carpet,13092807,
+395,brown_carpet,13092807,
+396,green_carpet,13092807,
+397,red_carpet,13092807,
+398,black_carpet,13092807,
+399,terracotta,7368816,
+400,coal_block,7368816,
+401,packed_ice,10526975,snow
+402,sunflower,31744,flower
+403,lilac,31744,flower
+404,rose_bush,31744,flower
+405,peony,31744,flower
+406,tall_grass,31744,flower
+407,large_fern,31744,flower
+408,white_banner,9402184,
+409,orange_banner,9402184,
+410,magenta_banner,9402184,
+411,light_blue_banner,9402184,
+412,yellow_banner,9402184,
+413,lime_banner,9402184,
+414,pink_banner,9402184,
+415,gray_banner,9402184,
+416,light_gray_banner,9402184,
+417,cyan_banner,9402184,
+418,purple_banner,9402184,
+419,blue_banner,9402184,
+420,brown_banner,9402184,
+421,green_banner,9402184,
+422,red_banner,9402184,
+423,black_banner,9402184,
+424,white_wall_banner,9402184,
+425,orange_wall_banner,9402184,
+426,magenta_wall_banner,9402184,
+427,light_blue_wall_banner,9402184,
+428,yellow_wall_banner,9402184,
+429,lime_wall_banner,9402184,
+430,pink_wall_banner,9402184,
+431,gray_wall_banner,9402184,
+432,light_gray_wall_banner,9402184,
+433,cyan_wall_banner,9402184,
+434,purple_wall_banner,9402184,
+435,blue_wall_banner,9402184,
+436,brown_wall_banner,9402184,
+437,green_wall_banner,9402184,
+438,red_wall_banner,9402184,
+439,black_wall_banner,9402184,
+440,red_sandstone,7368816,
+441,chiseled_red_sandstone,7368816,
+442,cut_red_sandstone,7368816,
+443,red_sandstone_stairs,7368816,
+444,oak_slab,9402184,
+445,spruce_slab,9402184,
+446,birch_slab,9402184,
+447,jungle_slab,9402184,
+448,acacia_slab,9402184,
+449,dark_oak_slab,9402184,
+450,stone_slab,7368816,
+451,smooth_stone_slab,7368816,
+452,sandstone_slab,7368816,
+453,cut_sandstone_slab,7368816,
+454,petrified_oak_slab,7368816,
+455,cobblestone_slab,7368816,
+456,brick_slab,7368816,
+457,stone_brick_slab,7368816,
+458,nether_brick_slab,7368816,
+459,quartz_slab,7368816,
+460,red_sandstone_slab,7368816,
+461,cut_red_sandstone_slab,7368816,
+462,purpur_slab,7368816,
+463,smooth_stone,7368816,
+464,smooth_sandstone,7368816,
+465,smooth_quartz,7368816,
+466,smooth_red_sandstone,7368816,
+467,spruce_fence_gate,9402184,
+468,birch_fence_gate,9402184,
+469,jungle_fence_gate,9402184,
+470,acacia_fence_gate,9402184,
+471,dark_oak_fence_gate,9402184,
+472,spruce_fence,9402184,
+473,birch_fence,9402184,
+474,jungle_fence,9402184,
+475,acacia_fence,9402184,
+476,dark_oak_fence,9402184,
+477,spruce_door,9402184,
+478,birch_door,9402184,
+479,jungle_door,9402184,
+480,acacia_door,9402184,
+481,dark_oak_door,9402184,
+482,end_rod,0,
+483,chorus_plant,31744,
+484,chorus_flower,31744,
+485,purpur_block,7368816,
+486,purpur_pillar,7368816,
+487,purpur_stairs,7368816,
+488,end_stone_bricks,7368816,
+489,beetroots,31744,
+490,grass_path,9923917,
+491,end_gateway,0,
+492,repeating_command_block,10987431,
+493,chain_command_block,10987431,
+494,frosted_ice,10526975,snow
+495,magma_block,7368816,
+496,nether_wart_block,8368696,
+497,red_nether_bricks,7368816,
+498,bone_block,7368816,
+499,structure_void,0,
+500,observer,7368816,
+501,shulker_box,8339378,
+502,white_shulker_box,8339378,
+503,orange_shulker_box,8339378,
+504,magenta_shulker_box,8339378,
+505,light_blue_shulker_box,8339378,
+506,yellow_shulker_box,8339378,
+507,lime_shulker_box,8339378,
+508,pink_shulker_box,8339378,
+509,gray_shulker_box,8339378,
+510,light_gray_shulker_box,8339378,
+511,cyan_shulker_box,8339378,
+512,purple_shulker_box,8339378,
+513,blue_shulker_box,8339378,
+514,brown_shulker_box,8339378,
+515,green_shulker_box,8339378,
+516,red_shulker_box,8339378,
+517,black_shulker_box,8339378,
+518,white_glazed_terracotta,7368816,
+519,orange_glazed_terracotta,7368816,
+520,magenta_glazed_terracotta,7368816,
+521,light_blue_glazed_terracotta,7368816,
+522,yellow_glazed_terracotta,7368816,
+523,lime_glazed_terracotta,7368816,
+524,pink_glazed_terracotta,7368816,
+525,gray_glazed_terracotta,7368816,
+526,light_gray_glazed_terracotta,7368816,
+527,cyan_glazed_terracotta,7368816,
+528,purple_glazed_terracotta,7368816,
+529,blue_glazed_terracotta,7368816,
+530,brown_glazed_terracotta,7368816,
+531,green_glazed_terracotta,7368816,
+532,red_glazed_terracotta,7368816,
+533,black_glazed_terracotta,7368816,
+534,white_concrete,7368816,
+535,orange_concrete,7368816,
+536,magenta_concrete,7368816,
+537,light_blue_concrete,7368816,
+538,yellow_concrete,7368816,
+539,lime_concrete,7368816,
+540,pink_concrete,7368816,
+541,gray_concrete,7368816,
+542,light_gray_concrete,7368816,
+543,cyan_concrete,7368816,
+544,purple_concrete,7368816,
+545,blue_concrete,7368816,
+546,brown_concrete,7368816,
+547,green_concrete,7368816,
+548,red_concrete,7368816,
+549,black_concrete,7368816,
+550,white_concrete_powder,16247203,
+551,orange_concrete_powder,16247203,
+552,magenta_concrete_powder,16247203,
+553,light_blue_concrete_powder,16247203,
+554,yellow_concrete_powder,16247203,
+555,lime_concrete_powder,16247203,
+556,pink_concrete_powder,16247203,
+557,gray_concrete_powder,16247203,
+558,light_gray_concrete_powder,16247203,
+559,cyan_concrete_powder,16247203,
+560,purple_concrete_powder,16247203,
+561,blue_concrete_powder,16247203,
+562,brown_concrete_powder,16247203,
+563,green_concrete_powder,16247203,
+564,red_concrete_powder,16247203,
+565,black_concrete_powder,16247203,
+566,kelp,4210943,
+567,kelp_plant,4210943,
+568,dried_kelp_block,8368696,
+569,turtle_egg,31744,
+570,dead_tube_coral_block,7368816,
+571,dead_brain_coral_block,7368816,
+572,dead_bubble_coral_block,7368816,
+573,dead_fire_coral_block,7368816,
+574,dead_horn_coral_block,7368816,
+575,tube_coral_block,7368816,
+576,brain_coral_block,7368816,
+577,bubble_coral_block,7368816,
+578,fire_coral_block,7368816,
+579,horn_coral_block,7368816,
+580,dead_tube_coral,7368816,
+581,dead_brain_coral,7368816,
+582,dead_bubble_coral,7368816,
+583,dead_fire_coral,7368816,
+584,dead_horn_coral,7368816,
+585,tube_coral,4210943,
+586,brain_coral,4210943,
+587,bubble_coral,4210943,
+588,fire_coral,4210943,
+589,horn_coral,4210943,
+590,dead_tube_coral_fan,7368816,
+591,dead_brain_coral_fan,7368816,
+592,dead_bubble_coral_fan,7368816,
+593,dead_fire_coral_fan,7368816,
+594,dead_horn_coral_fan,7368816,
+595,tube_coral_fan,4210943,
+596,brain_coral_fan,4210943,
+597,bubble_coral_fan,4210943,
+598,fire_coral_fan,4210943,
+599,horn_coral_fan,4210943,
+600,dead_tube_coral_wall_fan,7368816,
+601,dead_brain_coral_wall_fan,7368816,
+602,dead_bubble_coral_wall_fan,7368816,
+603,dead_fire_coral_wall_fan,7368816,
+604,dead_horn_coral_wall_fan,7368816,
+605,tube_coral_wall_fan,4210943,
+606,brain_coral_wall_fan,4210943,
+607,bubble_coral_wall_fan,4210943,
+608,fire_coral_wall_fan,4210943,
+609,horn_coral_wall_fan,4210943,
+610,sea_pickle,4210943,
+611,blue_ice,10526975,snow
+612,conduit,0,
+613,bamboo_sapling,9402184,flower
+614,bamboo,9402184,tree
+615,potted_bamboo,0,
+616,void_air,0,dirt
+617,cave_air,0,dirt
+618,bubble_column,4210943,
+619,polished_granite_stairs,7368816,
+620,smooth_red_sandstone_stairs,7368816,
+621,mossy_stone_brick_stairs,7368816,
+622,polished_diorite_stairs,7368816,
+623,mossy_cobblestone_stairs,7368816,
+624,end_stone_brick_stairs,7368816,
+625,stone_stairs,7368816,
+626,smooth_sandstone_stairs,7368816,
+627,smooth_quartz_stairs,7368816,
+628,granite_stairs,7368816,
+629,andesite_stairs,7368816,
+630,red_nether_brick_stairs,7368816,
+631,polished_andesite_stairs,7368816,
+632,diorite_stairs,7368816,
+633,polished_granite_slab,7368816,
+634,smooth_red_sandstone_slab,7368816,
+635,mossy_stone_brick_slab,7368816,
+636,polished_diorite_slab,7368816,
+637,mossy_cobblestone_slab,7368816,
+638,end_stone_brick_slab,7368816,
+639,smooth_sandstone_slab,7368816,
+640,smooth_quartz_slab,7368816,
+641,granite_slab,7368816,
+642,andesite_slab,7368816,
+643,red_nether_brick_slab,7368816,
+644,polished_andesite_slab,7368816,
+645,diorite_slab,7368816,
+646,brick_wall,7368816,
+647,prismarine_wall,7368816,
+648,red_sandstone_wall,7368816,
+649,mossy_stone_brick_wall,7368816,
+650,granite_wall,7368816,
+651,stone_brick_wall,7368816,
+652,nether_brick_wall,7368816,
+653,andesite_wall,7368816,
+654,red_nether_brick_wall,7368816,
+655,sandstone_wall,7368816,
+656,end_stone_brick_wall,7368816,
+657,diorite_wall,7368816,
+658,scaffolding,0,
+659,loom,9402184,
+660,barrel,9402184,
+661,smoker,7368816,
+662,blast_furnace,7368816,
+663,cartography_table,9402184,
+664,fletching_table,9402184,
+665,grindstone,10987431,
+666,lectern,9402184,
+667,smithing_table,9402184,
+668,stonecutter,7368816,
+669,bell,10987431,
+670,lantern,10987431,
+671,campfire,9402184,
+672,sweet_berry_bush,31744,
+673,structure_block,10987431,
+674,jigsaw,10987431,
+675,composter,9402184,
+676,bee_nest,9402184,
+677,beehive,9402184,
+678,honey_block,10791096,
+679,honeycomb_block,10791096,
diff --git a/imaginaire/model_utils/gancraft/mc_utils.py b/imaginaire/model_utils/gancraft/mc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..faa87e04060866541761586ffe8e41bcce9ca44c
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_utils.py
@@ -0,0 +1,388 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import csv
+import time
+# For binary dilation
+from scipy import ndimage
+import os
+from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper
+
+
+def load_voxel_new(voxel_path, shape=[256, 512, 512]):
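+    r"""Load a raw int32 voxel file and return it as a [Y, X, Z] torch.IntTensor.
+
+    The reshape/transpose below assumes the file is stored as 16 x 16 chunk
+    columns (Minecraft-style), which are rearranged here into a dense grid.
+    """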
+    voxel_world = np.fromfile(voxel_path, dtype='int32')
+    voxel_world = voxel_world.reshape(
+        shape[1]//16, shape[2]//16, 16, 16, shape[0])
+    voxel_world = voxel_world.transpose(4, 0, 2, 1, 3)
+    voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2])
+    voxel_world = np.ascontiguousarray(voxel_world)
+    voxel_world = torch.from_numpy(voxel_world.astype(np.int32))
+    return voxel_world
+
+
+def gen_corner_voxel(voxel):
+    r"""Converting voxel center array to voxel corner array. The size of the
+    produced array grows by 1 on every dimension.
+
+    Args:
+        voxel (torch.IntTensor, CPU): Input voxel of three dimensions
+    """
+    structure = np.zeros([3, 3, 3], dtype=bool)  # np.bool was removed in recent NumPy; the builtin bool is equivalent
+    structure[1:, 1:, 1:] = True
+    voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1))
+    corners = ndimage.binary_dilation(voxel_p.numpy(), structure)
+    corners = torch.tensor(corners, dtype=torch.int32)
+    return corners
+
+
+def calc_height_map(voxel_t):
+    r"""Calculate height map given a voxel grid [Y, X, Z] as input.
+    The height is defined as the Y index of the topmost surface (non-air) block.
+
+    Args:
+        voxel_t (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions
+    Output:
+        heightmap (X x Z torch.IntTensor)
+    """
+    start_time = time.time()
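+    # This relies on torch.max returning the index of the first maximal element:
+    # taking it over the flipped (top-down) occupancy mask gives the topmost non-air block.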
+    m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False)
+    heightmap = voxel_t.shape[0] - 1 - h
+    heightmap[m == 0] = 0  # Special case when the whole vertical column is empty
+
+    elapsed_time = time.time() - start_time
+    print("[GANcraft-utils] Heightmap time: {}".format(elapsed_time))
+    return heightmap
+
+
+def trans_vec_homo(m, v, is_vec=False):
+    r"""3-dimensional Homogeneous matrix and regular vector multiplication
+    Convert v to homogeneous vector, perform M-V multiplication, and convert back
+    Note that this function does not support autograd.
+
+    Args:
+        m (4 x 4 tensor): a homogeneous matrix
+        v (3 tensor): a 3-d vector
+        vec (bool): if true, v is direction. Otherwise v is point
+    """
+    if is_vec:
+        v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype)
+    else:
+        v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype)
+    v = torch.mv(m, v)
+    if not is_vec:
+        v = v / v[3]
+    v = v[:3]
+    return v
+
+
+def cumsum_exclusive(tensor, dim):
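+    # Exclusive prefix sum along `dim`: the first element is 0 and element i holds
+    # the sum of elements [0, i). Used below to accumulate the free energy of the
+    # preceding samples in volume rendering.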
+    cumsum = torch.cumsum(tensor, dim)
+    cumsum = torch.roll(cumsum, 1, dim)
+    cumsum.index_fill_(dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0)
+    return cumsum
+
+
+def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4):
+    r"""    Make best effort to sample points within the same distance for every ray.
+    Exception: When there is not enough voxel.
+
+    Args:
+        depth2 (N x 2 x 256 x 256 x 4 x 1 tensor):
+        - N: Batch.
+        - 2: Entrance / exit depth for each intersected box.
+        - 256, 256: Height, Width.
+        - 4: Number of intersected boxes along the ray.
+        - 1: One extra dim for consistent tensor dims.
+        depth2 can include NaNs.
+        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
+        use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
+        sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
+    """
+
+    bs = depth2.size(0)
+    dim0 = depth2.size(2)
+    dim1 = depth2.size(3)
+    dists = depth2[:, 1] - depth2[:, 0]
+    dists[torch.isnan(dists)] = 0  # N, 256, 256, 4, 1
+    accu_depth = torch.cumsum(dists, dim=-2)  # N, 256, 256, 4, 1
+    total_depth = accu_depth[..., [-1], :]  # N, 256, 256, 1, 1
+
+    total_depth = torch.clamp(total_depth, None, sample_depth)
+
+    # Ignore out of range box boundaries. Fill with random samples.
+    if use_box_boundaries:
+        boundary_samples = accu_depth.clone().detach()
+        boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
+        bad_mask = (accu_depth > sample_depth) | (dists == 0)
+        boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]
+
+    rand_shape = [bs, dim0, dim1, nsamples, 1]
+    # 256, 256, N, 1
+    if deterministic:
+        rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device)
+        rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2)[1:-1]
+    else:
+        rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device)  # 256, 256, N, 1
+        # Stratified sampling as in NeRF
+        rand_samples = rand_samples / nsamples
+        rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1]
+    rand_samples = rand_samples * total_depth  # 256, 256, N, 1
+
+    # Can also include boundaries
+    if use_box_boundaries:
+        rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros(
+            [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2)
+    rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)
+
+    midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
+    new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]
+
+    # Scatter the random samples back
+    # 256, 256, 1, M, 1 > 256, 256, N, 1, 1
+    idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)  # 256, 256, M, 1
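+    # idx counts, for each sample midpoint, how many boxes end before it, i.e. the
+    # index of the intersected box that each sample falls into.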
+    # print(idx.shape, idx.max(), idx.min()) # max 3, min 0
+
+    depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :]  # There might be NaNs!
+    depth_deltas = torch.cumsum(depth_deltas, dim=-2)
+    depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2)
+    heads = torch.gather(depth_deltas, -2, idx)  # 256 256 M 1
+    # heads = torch.gather(depth2[0], -2, idx) # 256 256 M 1
+
+    # print(torch.any(torch.isnan(heads)))
+    rand_depth = heads + midpoints  # 256 256 N 1
+
+    return rand_depth, new_dists, idx
+
+
+def volum_rendering_relu(sigma, dists, dim=2):
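+    # Volume-rendering weights with ReLU-activated densities:
+    #   w_i = (1 - exp(-relu(sigma_i) * dists_i)) * exp(-sum_{j<i} relu(sigma_j) * dists_j)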
+    free_energy = F.relu(sigma) * dists
+
+    a = 1 - torch.exp(-free_energy.float())  # probability that this segment is not empty
+    b = torch.exp(-cumsum_exclusive(free_energy, dim=dim))  # probability that everything before it is empty (transmittance)
+    probs = a * b  # probability that the ray hits something in this segment
+
+    return probs
+
+
+class McVoxel(nn.Module):
+    r"""Voxel management."""
+
+    def __init__(self, voxel_t, preproc_ver):
+        super(McVoxel, self).__init__()
+        # Filter voxel
+        voxel_t[voxel_t == 246] = 0  # lily_pad
+        voxel_t[voxel_t == 241] = 0  # vine
+        voxel_t[voxel_t == 611] = 26  # Blue ice -> water
+        voxel_t[voxel_t == 183] = 26  # ice -> water
+        voxel_t[voxel_t == 401] = 25  # Packed ice -> bedrock
+
+        if preproc_ver >= 3 and preproc_ver < 6:
+            voxel_t[voxel_t == 27] = 25  # Lava -> bedrock
+            voxel_t[voxel_t == 616] = 9  # void_air -> dirt
+            voxel_t[voxel_t == 617] = 25  # cave_air -> bedrock
+
+        if preproc_ver >= 6:
+            voxel_t[voxel_t == 616] = 0  # void_air -> air
+            voxel_t[voxel_t == 617] = 0  # cave_air -> air
+
+        # Simplify voxel
+        structure = ndimage.generate_binary_structure(3, 3)
+        mask = voxel_t.numpy() > 0
+        if preproc_ver == 4:  # Hollow bottom
+            mask = ndimage.morphology.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
+            voxel_t[mask] = 0
+        if preproc_ver >= 5:  # Close cell before hollow bottom
+            mask = ndimage.morphology.binary_dilation(mask, iterations=1, border_value=1)
+            mask = ndimage.morphology.binary_erosion(mask, iterations=1, border_value=1)
+            mask = ndimage.morphology.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
+            voxel_t[mask] = 0
+
+        self.register_buffer('voxel_t', voxel_t, persistent=False)
+
+        self.trans_mat = torch.eye(4)  # Transform voxel to world
+        # Generate heightmap for camera trajectory generation
+        self.heightmap = calc_height_map(self.voxel_t)
+        self._truncate_voxel()
+        # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor)
+        corner_t = gen_corner_voxel(self.voxel_t)
+        self.register_buffer('corner_t', corner_t, persistent=False)
+
+        # Generate 3D position to 1D feature LUT table
+        nfilledvox = torch.sum(self.corner_t > 0)
+        print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t)))
+        # Zero means non-existent voxel.
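+        # Each filled corner receives a unique 1-based index so that per-corner
+        # features can be stored in a dense table and looked up sparsely.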
+        self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32)
+        self.nfilledvox = nfilledvox
+
+    def world2local(self, v, is_vec=False):
+        mat_world2local = torch.inverse(self.trans_mat)
+        return trans_vec_homo(mat_world2local, v, is_vec)
+
+    def _truncate_voxel(self):
+        gnd_level = self.heightmap.min()
+        sky_level = self.heightmap.max() + 1
+        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
+        self.trans_mat[0, 3] += gnd_level
+        print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))
+
+    def is_sea(self, loc):
+        r"""loc: [2]: x, z."""
+        x = int(loc[1])
+        z = int(loc[2])
+        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
+            print('[McVoxel] is_sea(): Index out of bound.')
+            return True
+        y = self.heightmap[x, z] - self.trans_mat[0, 3]
+        y = int(y)
+        if self.voxel_t[y, x, z] == 26:
+            print('[McVoxel] is_sea(): Get a sea.')
+            print(self.voxel_t[y, x, z], self.voxel_t[y+1, x, z])
+            return True
+        else:
+            return False
+
+
+class MCLabelTranslator:
+    r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set."""
+
+    def __init__(self):
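+        # Build lookup tables from the bundled CSV files:
+        #   id2name / id2color: Minecraft block id -> block name / default color
+        #   id2glbl:            Minecraft block id -> GauGAN label name
+        # plus derived id -> GauGAN color and id -> COCO-Stuff index tables.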
+        this_path = os.path.dirname(os.path.abspath(__file__))
+        # Load voxel name lut
+        id2name_lut = {}
+        id2color_lut = {}
+        id2glbl_lut = {}
+        with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            for row in csvreader:
+                id2name_lut[int(row[0])] = row[1]
+                id2color_lut[int(row[0])] = int(row[2])
+                id2glbl_lut[int(row[0])] = row[3]
+
+        # Load GauGAN color lut
+        glbl2color_lut = {}
+        glbl2cocoidx_lut = {}
+        with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
+            csvreader = csv.reader(csvfile, delimiter=',')
+            cocoidx = 1  # 0 is "Others"
+            for row in csvreader:
+                color = int(row[1].lstrip('#'), 16)
+                glbl2color_lut[row[0]] = color
+                glbl2cocoidx_lut[row[0]] = cocoidx
+                cocoidx += 1
+
+        # Generate id2ggcolor lut
+        id2ggcolor_lut = {}
+        for k, v in id2glbl_lut.items():
+            if v:
+                id2ggcolor_lut[k] = glbl2color_lut[v]
+            else:
+                id2ggcolor_lut[k] = 0
+
+        # Generate id2cocoidx
+        id2cocoidx_lut = {}
+        for k, v in id2glbl_lut.items():
+            if v:
+                id2cocoidx_lut[k] = glbl2cocoidx_lut[v]
+            else:
+                id2cocoidx_lut[k] = 0
+
+        self.id2color_lut = id2color_lut
+        self.id2name_lut = id2name_lut
+        self.id2glbl_lut = id2glbl_lut
+        self.id2ggcolor_lut = id2ggcolor_lut
+        self.id2cocoidx_lut = id2cocoidx_lut
+
+        if True:
+            mapper = ReducedLabelMapper()
+            mcid2rdid_lut = mapper.mcid2rdid_lut
+            mcid2rdid_lut = torch.tensor(mcid2rdid_lut, dtype=torch.long)
+            self.mcid2rdid_lut = mcid2rdid_lut
+            self.num_reduced_lbls = len(mapper.reduced_lbls)
+            self.ignore_id = mapper.ignore_id
+            self.dirt_id = mapper.dirt_id
+            self.water_id = mapper.water_id
+
+            self.mapper = mapper
+
+            ggid2rdid_lut = mapper.ggid2rdid + [0]  # Last index is ignore
+            ggid2rdid_lut = torch.tensor(ggid2rdid_lut, dtype=torch.long)
+            self.ggid2rdid_lut = ggid2rdid_lut
+        if True:
+            mc2coco_lut = list(zip(*sorted([(k, v) for k, v in self.id2cocoidx_lut.items()])))[1]
+            mc2coco_lut = torch.tensor(mc2coco_lut, dtype=torch.long)
+            self.mc2coco_lut = mc2coco_lut
+
+    def gglbl2ggid(self, gglbl):
+        return self.mapper.gglbl2ggid[gglbl]
+
+    def mc2coco(self, mc):
+        self.mc2coco_lut = self.mc2coco_lut.to(mc.device)
+        coco = self.mc2coco_lut[mc.long()]
+        return coco
+
+    def mc2reduced(self, mc, ign2dirt=False):
+        self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device)
+        reduced = self.mcid2rdid_lut[mc.long()]
+        if ign2dirt:
+            reduced[reduced == self.ignore_id] = self.dirt_id
+        return reduced
+
+    def coco2reduced(self, coco):
+        self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device)
+        reduced = self.ggid2rdid_lut[coco.long()]
+        return reduced
+
+    def get_num_reduced_lbls(self):
+        return self.num_reduced_lbls
+
+    @staticmethod
+    def uint32_to_4uint8(x):
+        dt1 = np.dtype(('i4', [('bytes', 'u1', 4)]))
+        color = x.view(dtype=dt1)['bytes']
+        return color
+
+    def mc_color(self, img):
+        r"""Obtaining Minecraft default color.
+
+        Args:
+            img (H x W x 1 int32 numpy tensor): Segmentation map.
+        """
+        lut = self.id2color_lut
+        lut = list(zip(*sorted([(k, v) for k, v in lut.items()])))[1]
+        lut = np.array(lut, dtype=np.uint32)
+        rgb = lut[img]
+        rgb = self.uint32_to_4uint8(rgb)[..., :3]
+
+        return rgb
+
+
+def rand_crop(cam_c, cam_res, target_res):
+    r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering
+    with the old parameters and then crop out target_res.
+    """
+    d0 = np.random.randint(cam_res[0] - target_res[0] + 1)
+    d1 = np.random.randint(cam_res[1] - target_res[1] + 1)
+    cam_c = [cam_c[0]-d0, cam_c[1]-d1]
+    return cam_c
+
+
+def segmask_smooth(seg_mask, kernel_size=7):
+    labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2)
+    onehot_idx = torch.argmax(labels, dim=1, keepdims=True)
+    labels.fill_(0.0)
+    labels.scatter_(1, onehot_idx, 1.0)
+    return labels
+
+
+def colormap(x, cmap='viridis'):
+    x = np.nan_to_num(x, nan=np.nan, posinf=np.nan, neginf=np.nan)  # map infinities to NaN so nanmin/nanmax ignore them
+    x = x - np.nanmin(x)
+    x = x / np.nanmax(x)
+    rgb = plt.get_cmap(cmap)(x)[..., :3]
+    return rgb
diff --git a/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c82cc05572bbace78643911e9789f4f2cfd15f0e
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv
@@ -0,0 +1,12 @@
+ignore
+sky
+tree
+dirt
+flower
+grass
+gravel
+water
+rock
+stone
+sand
+snow
\ No newline at end of file
diff --git a/imaginaire/model_utils/gancraft/voxlib/Makefile b/imaginaire/model_utils/gancraft/voxlib/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..de903af09d2feda89118edc743048d70ce963d24
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/Makefile
@@ -0,0 +1,11 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+all:
+	python setup.py build_ext --inplace
+	python setup.py install
+
+clean:
+	rm -rf *.o *.a *.so test build
diff --git a/imaginaire/model_utils/gancraft/voxlib/__init__.py b/imaginaire/model_utils/gancraft/voxlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fce15c92b99ef0ad9feaeb31d75664a0971b385
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .positional_encoding import positional_encoding
+from .sp_trilinear import sparse_trilinear_interp_worldcoord
+from voxlib import ray_voxel_intersection_perspective
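+# Note: ray_voxel_intersection_perspective is imported directly from the compiled
+# CUDA extension (also named `voxlib`); the two Python wrappers above add autograd support.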
diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef95d0bd47103233e1c4f32c70dd5b463c9eac1d
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.autograd import Function
+import voxlib
+
+# Cheatsheet:
+# mark_dirty() must be used to mark any input that is modified inplace by the forward function.
+# mark_non_differentiable()
+
+
+class PositionalEncodingFunction(Function):
+    @staticmethod
+    def forward(ctx, in_feature, pe_degrees, dim, incl_orig):
+        out_feature = voxlib.positional_encoding(in_feature, pe_degrees, dim, incl_orig)
+
+        ctx.save_for_backward(out_feature)
+        ctx.pe_degrees = pe_degrees
+        ctx.dim = dim
+        ctx.incl_orig = incl_orig
+
+        return out_feature
+
+    @staticmethod
+    def backward(ctx, out_feature_grad):
+        out_feature, = ctx.saved_tensors
+
+        # torch::Tensor positional_encoding_backward(const torch::Tensor& out_feature_grad,
+        # const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) {
+        in_feature_grad = voxlib.positional_encoding_backward(
+            out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig)
+
+        return in_feature_grad, None, None, None
+
+
+def positional_encoding(in_feature, pe_degrees, dim=-1, incl_orig=False):
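+    r"""Positional encoding (CUDA): concatenates sin(x * pi * 2^i) and cos(x * pi * 2^i)
+    for i in [0, pe_degrees) along `dim`, optionally followed by the original features.
+    See positional_encoding_pt below for the pure-PyTorch reference implementation.
+    """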
+    return PositionalEncodingFunction.apply(in_feature, pe_degrees, dim, incl_orig)
+
+# input: N, C
+# output: N, C*pe_degrees*2 (plus C if incl_orig)
+
+
+def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False):
+    import numpy as np
+    pe_stor = []
+    for i in range(pe_degrees):
+        pe_stor.append(torch.sin(pts * np.pi * 2 ** i))
+        pe_stor.append(torch.cos(pts * np.pi * 2 ** i))
+    if incl_orig:
+        pe_stor.append(pts)
+    pe = torch.cat(pe_stor, dim=dim)
+    return pe
+
+
+if __name__ == '__main__':
+    x = torch.rand(384, 512, 5, 48).cuda() * 1024
+    y = positional_encoding_pt(x, 4, incl_orig=True)
+    y2 = positional_encoding(x, 4, incl_orig=True)
+
+    print(torch.abs(y - y2))
+    print(torch.allclose(y, y2, rtol=1e-05, atol=1e-05))
diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..278fc165991d70aae549dd376b979f53703a0647
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu
@@ -0,0 +1,285 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <math_constants.h>
+#include <time.h>
+
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+struct PE_Params {
+    int ndegrees;
+    int pre_size;
+    int post_size;
+    bool incl_orig;
+};
+
+// const int TILE_DIM_X = 16;  // channel dim
+// const int TILE_DIM_Y = 64;  // entry dim
+// dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+TILE_DIM_Y-1)/TILE_DIM_Y, 1);
+// dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+template <int TILE_DIM_X, int TILE_DIM_Y, int DUP_Y>
+__global__ void positional_encoding_kernel(
+    float* __restrict__ out_feature,
+    const float* __restrict__ in_feature, const PE_Params p) {
+
+    const int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+    const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y;
+    if (idx_feat >= p.post_size) {
+        return;
+    }
+
+    int stride = p.ndegrees*2;
+    if (p.incl_orig) {
+        stride += 1;
+    }
+
+    for (int j=0; j<DUP_Y; j++) {
+        int idx_entry = idx_entry_base + j;
+        if (idx_entry >= p.pre_size) {
+            return;
+        }
+        float data = in_feature[idx_entry*p.post_size + idx_feat];
+
+        for (int i=0; i<p.ndegrees; i++) {
+            float rad = data * CUDART_PI_F * exp2f(i);
+            //float rad = scalbnf(data * CUDART_PI_F, i);
+            float sinrad, cosrad;
+            sincosf(rad, &sinrad, &cosrad);
+            out_feature[idx_entry*p.post_size*stride + i*2*p.post_size + idx_feat] = sinrad;
+            out_feature[idx_entry*p.post_size*stride + (i*2+1)*p.post_size + idx_feat] = cosrad;
+        }
+        if (p.incl_orig) {
+            out_feature[idx_entry*p.post_size*stride + (stride-1)*p.post_size + idx_feat] = data;
+        }
+    }
+}
+
+template <int TILE_DIM_X, int TILE_DIM_Y, int DUP_Y>
+__global__ void positional_encoding_backward_kernel(
+    float* __restrict__ in_feature_grad,
+    const float* __restrict__ out_feature_grad, const float* __restrict__ out_feature, const PE_Params p) {
+
+    int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+    const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y;
+
+    if (idx_feat >= p.post_size) {
+        return;
+    }
+
+    int stride = p.ndegrees*2;
+    if (p.incl_orig) {
+        stride += 1;
+    }
+
+    for (int j=0; j<DUP_Y; j++) {
+        int idx_entry = idx_entry_base + j;
+        if (idx_entry >= p.pre_size) {
+            return;
+        }
+
+        float grad = 0.0f;
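+        // Chain rule: d/dx sin(x*pi*2^i) =  pi*2^i * cos(x*pi*2^i) and
+        //             d/dx cos(x*pi*2^i) = -pi*2^i * sin(x*pi*2^i).
+        // The stored forward outputs (sin, cos) are reused instead of recomputing them.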
+        for (int i=0; i<p.ndegrees; i++) {
+            float grad_t;
+
+            grad_t = out_feature_grad[idx_entry*p.post_size*stride + i*2*p.post_size + idx_feat] *
+                out_feature[idx_entry*p.post_size*stride + (i*2+1)*p.post_size + idx_feat];        // cos(x*pi*(2^i))
+
+            grad_t -= out_feature_grad[idx_entry*p.post_size*stride + (i*2+1)*p.post_size + idx_feat] *
+                out_feature[idx_entry*p.post_size*stride + (i*2)*p.post_size + idx_feat];        // -sin(x*pi*(2^i))
+
+            grad += grad_t * CUDART_PI_F * exp2f(i);
+        }
+        if (p.incl_orig) {
+            grad += out_feature_grad[idx_entry*p.post_size*stride + (stride-1)*p.post_size + idx_feat];
+        }
+
+        in_feature_grad[idx_entry*p.post_size + idx_feat] = grad;
+    }
+}
+
+
+// Input:
+//      in_feature:     float32 [..., N, ...]
+//      ndegree:        int32   Degrees of PE encoding
+//      dim:            int32   Dimension to concatenate
+//      incl_orig:      bool    Whether to include original feature vector or not
+// Output:
+//      out_feature:     float32 [..., N*ndegree*2+incl_orig, ...]
+// std::vector<torch::Tensor>
+torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig) {
+    CHECK_CUDA(in_feature);
+
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+    torch::Device device = in_feature.device();
+
+    assert(in_feature.dtype() == torch::kFloat32);
+
+    // Handle negative index
+    if (dim < 0) {
+        dim = in_feature.dim() + dim;
+    }
+    assert(dim >= 0 && dim < in_feature.dim());
+
+    // No need to be contiguous. Input and output has the same memory layout.
+    CHECK_CONTIGUOUS(in_feature);
+
+    PE_Params p;
+    p.ndegrees = ndegrees;
+    p.incl_orig = incl_orig;
+
+    // This only works for contiguous tensors...
+    int pre_size = 1;
+    int post_size = 1;
+    for (int i=0; i<dim; i++) {
+        pre_size *= in_feature.size(i);
+    }
+    for (int i=dim; i<in_feature.dim(); i++) {
+        post_size *= in_feature.size(i);
+    }
+    p.pre_size = pre_size;
+    p.post_size = post_size;
+
+    // Calculate output shape
+    std::vector<int64_t> out_feature_shape;
+    for (int i=0; i<in_feature.dim(); i++) {
+        int64_t dim_t = in_feature.size(i);
+        if (i == dim) {
+            if (incl_orig) {
+                dim_t = dim_t*(ndegrees*2+1);
+            } else {
+                dim_t = dim_t*ndegrees*2;
+            }
+        }
+        out_feature_shape.push_back(dim_t);
+    }
+
+    // Always produce contiguous output
+    torch::Tensor out_feature = torch::empty(out_feature_shape, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+    // Launch CUDA kernel
+    // Case 1: Concat at the last dimension (post_size < pre_size)  -->  Each thread handle a single post_size
+    // Case 2: Concat at the middle (post_size > pre_size)  -->  Each thread handle
+    const int TILE_DIM_X = 16;  // channel dim
+    const int TILE_DIM_Y = 64;  // entry dim
+    //const int DUP_Y = 4; // Each thread handle multiple entries to save threads
+    const int DUP_Y = 8; // DGXA 64 samples per ray @ 256x256
+    dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1);
+    dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+    positional_encoding_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_Y><<<dimGrid, dimBlock, 0, stream>>>(
+        out_feature.data_ptr<float>(),
+        in_feature.data_ptr<float>(), p
+    );
+
+    THCudaCheck(cudaGetLastError());
+    return out_feature;
+}
+
+//in_feature_grad = voxrender_op.positional_encoding_backward(out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig);
+// Input:
+//      out_feature_grad:   float32 [..., N*ndegree*2+incl_orig, ...]
+//      out_feature:        float32 [..., N*ndegree*2+incl_orig, ...]
+//      ndegrees:           int32   Degrees of PE encoding
+//      dim:                int32   Dimension to concatenate
+//      incl_orig:          bool    Whether to include original feature vector or not
+// Output:
+//      in_feature_grad:    float32 [..., N, ...]
+// std::vector<torch::Tensor>
+torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad_, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) {
+    CHECK_CUDA(out_feature_grad_);
+    CHECK_CUDA(out_feature);
+
+    const torch::Tensor out_feature_grad = out_feature_grad_.contiguous();
+
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+    torch::Device device = out_feature_grad.device();
+
+    assert(out_feature_grad.dtype() == torch::kFloat32);
+    assert(out_feature.dtype() == torch::kFloat32);
+    assert(out_feature_grad.sizes() == out_feature.sizes());
+
+    // Handle negative index
+    if (dim < 0) {
+        dim = out_feature.dim() + dim;
+    }
+    assert(dim >= 0 && dim < out_feature.dim());
+
+    CHECK_CONTIGUOUS(out_feature_grad);
+    CHECK_CONTIGUOUS(out_feature);
+
+    PE_Params p;
+    p.ndegrees = ndegrees;
+    p.incl_orig = incl_orig;
+
+    int expansion_factor = ndegrees*2;
+    if (incl_orig) {
+        expansion_factor += 1;
+    }
+    // This only works for contiguous tensors...
+    int pre_size = 1;
+    int post_size = 1;
+    for (int i=0; i<dim; i++) {
+        pre_size *= out_feature.size(i);
+    }
+    for (int i=dim; i<out_feature.dim(); i++) {
+        post_size *= out_feature.size(i);
+    }
+    post_size = post_size / expansion_factor;
+    p.pre_size = pre_size;
+    p.post_size = post_size;
+
+    // Calculate output shape
+    std::vector<int64_t> out_feature_shape;
+    for (int i=0; i<out_feature.dim(); i++) {
+        int64_t dim_t = out_feature.size(i);
+        if (i == dim) {
+            dim_t = dim_t / expansion_factor;
+        }
+        out_feature_shape.push_back(dim_t);
+    }
+
+    // Always produce contiguous output
+    torch::Tensor in_feature_grad = torch::empty(out_feature_shape, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+
+    // Launch CUDA kernel
+    // Case 1: Concat at the last dimension (post_size < pre_size)  -->  Each thread handle a single post_size
+    // Case 2: Concat at the middle (post_size > pre_size)  -->  Each thread handle
+    const int TILE_DIM_X = 16;  // channel dim
+    const int TILE_DIM_Y = 64;  // entry dim
+    //const int DUP_Y = 4; // Nothing to amortize
+    const int DUP_Y = 8; // DGXA
+    dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1);
+    dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+    positional_encoding_backward_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_Y><<<dimGrid, dimBlock, 0, stream>>>(
+        in_feature_grad.data_ptr<float>(),
+        out_feature_grad.data_ptr<float>(), out_feature.data_ptr<float>(), p
+    );
+
+    THCudaCheck(cudaGetLastError());
+
+    return in_feature_grad;
+}
diff --git a/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7ef22dc309e2eb6d944c50d917235f0c62219cb6
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu
@@ -0,0 +1,325 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+//
+// The ray marching algorithm used in this file is a variety of modified Bresenham method:
+// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf
+// Search for "voxel traversal algorithm" for related information
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <curand.h>
+#include <curand_kernel.h>
+#include <time.h>
+
+//#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+
+#include "voxlib_common.h"
+
+struct RVIP_Params {
+    int voxel_dims[3];
+    int voxel_strides[3];
+    int max_samples;
+    int img_dims[2];
+    // Camera parameters
+    float cam_ori[3];
+    float cam_fwd[3];
+    float cam_side[3];
+    float cam_up[3];
+    float cam_c[2];
+    float cam_f;
+    //unsigned long seed;
+};
+
+/*
+    out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
+    out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
+    out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1],           1, 3]
+    Image coordinates refer to the center of the pixel
+    [0, 0, 0] at voxel coordinate is at the corner of the corner block (instead of at the center)
+*/
+template <int TILE_DIM>
+static __global__ void ray_voxel_intersection_perspective_kernel(int32_t* __restrict__ out_voxel_id, float* __restrict__ out_depth, float* __restrict__ out_raydirs,
+const int32_t* __restrict__ in_voxel, const RVIP_Params p) {
+
+    int img_coords[2];
+    img_coords[1] = blockIdx.x*TILE_DIM+threadIdx.x;
+    img_coords[0] = blockIdx.y*TILE_DIM+threadIdx.y;
+    if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
+        return;
+    }
+    int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];
+
+    // Calculate ray origin and direction
+    float rayori[3], raydir[3];
+    rayori[0] = p.cam_ori[0];
+    rayori[1] = p.cam_ori[1];
+    rayori[2] = p.cam_ori[2];
+
+    // Camera intrinsics
+    float ndc_imcoords[2];
+    ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
+    ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];
+
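+    // Pinhole camera: ray direction = up * y_ndc + side * x_ndc + forward * focal length.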
+    raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] + p.cam_fwd[0] * p.cam_f;
+    raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] + p.cam_fwd[1] * p.cam_f;
+    raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] + p.cam_fwd[2] * p.cam_f;
+    normalize<float, 3>(raydir);
+
+    // Save out_raydirs
+    out_raydirs[pix_index*3] = raydir[0];
+    out_raydirs[pix_index*3+1] = raydir[1];
+    out_raydirs[pix_index*3+2] = raydir[2];
+
+    float axis_t[3];
+    int axis_int[3];
+    //int axis_intbound[3];
+
+    // Current voxel
+    axis_int[0] = floorf(rayori[0]);
+    axis_int[1] = floorf(rayori[1]);
+    axis_int[2] = floorf(rayori[2]);
+
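+    // DDA-style traversal setup: axis_t[i] is the ray parameter t at which the ray
+    // crosses the next voxel boundary along axis i; the marching loop below always
+    // advances along the axis with the smallest axis_t.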
+    #pragma unroll
+    for (int i=0; i<3; i++) {
+        if (raydir[i] > 0) {
+            // Initial t value
+            // Handle boundary case where rayori[i] is a whole number. Always round Up for the next block
+            //axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) / raydir[i];
+            axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
+        } else if (raydir[i] < 0) {
+            axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
+        } else {
+            axis_t[i] = HUGE_VALF;
+        }
+    }
+
+    // Fused raymarching and sampling
+    bool quit = false;
+    for (int cur_plane=0; cur_plane < p.max_samples; cur_plane++) { // Last cycle is for calculating p2
+        float t = nanf("0");
+        float t2 = nanf("0");
+        int32_t blk_id = 0;
+        // Find the next intersection
+        while (!quit) {
+            // Find the next smallest t
+            float tnow;
+            /*
+            #pragma unroll
+            for (int i=0; i<3; i++) {
+                if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+                    // Update current t
+                    tnow = axis_t[i];
+                    // Update t candidates
+                    if (raydir[i] > 0) {
+                        axis_int[i] += 1;
+                        if (axis_int[i] >= p.voxel_dims[i]) {
+                            quit = true;
+                        }
+                        axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
+                    } else {
+                        axis_int[i] -= 1;
+                        if (axis_int[i] < 0) {
+                            quit = true;
+                        }
+                        axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
+                    }
+                    break; // Avoid advancing multiple steps as axis_t is updated
+                }
+            }
+            */
+            // Hand unroll
+            if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+                // Update current t
+                tnow = axis_t[0];
+                // Update t candidates
+                if (raydir[0] > 0) {
+                    axis_int[0] += 1;
+                    if (axis_int[0] >= p.voxel_dims[0]) {
+                        quit = true;
+                    }
+                    axis_t[0] = ((float)(axis_int[0]+1) - rayori[0]) / raydir[0];
+                } else {
+                    axis_int[0] -= 1;
+                    if (axis_int[0] < 0) {
+                        quit = true;
+                    }
+                    axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
+                }
+            } else if (axis_t[1] <= axis_t[2]) {
+                tnow = axis_t[1];
+                if (raydir[1] > 0) {
+                    axis_int[1] += 1;
+                    if (axis_int[1] >= p.voxel_dims[1]) {
+                        quit = true;
+                    }
+                    axis_t[1] = ((float)(axis_int[1]+1) - rayori[1]) / raydir[1];
+                } else {
+                    axis_int[1] -= 1;
+                    if (axis_int[1] < 0) {
+                        quit = true;
+                    }
+                    axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
+                }
+            } else {
+                tnow = axis_t[2];
+                if (raydir[2] > 0) {
+                    axis_int[2] += 1;
+                    if (axis_int[2] >= p.voxel_dims[2]) {
+                        quit = true;
+                    }
+                    axis_t[2] = ((float)(axis_int[2]+1) - rayori[2]) / raydir[2];
+                } else {
+                    axis_int[2] -= 1;
+                    if (axis_int[2] < 0) {
+                        quit = true;
+                    }
+                    axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
+                }
+            }
+
+            if (quit) {
+                break;
+            }
+
+            // Skip empty space
+            // Could there be deadlock if the ray direction is away from the world?
+            if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] || axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] || axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
+                continue;
+            }
+
+            // Test intersection using voxel grid
+            blk_id = in_voxel[axis_int[0]*p.voxel_strides[0] + axis_int[1]*p.voxel_strides[1] + axis_int[2]*p.voxel_strides[2]];
+            if (blk_id == 0) {
+                continue;
+            }
+
+            // Now that there is an intersection
+            t = tnow;
+            // Calculate t2
+            /*
+            #pragma unroll
+            for (int i=0; i<3; i++) {
+                if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+                    t2 = axis_t[i];
+                    break;
+                }
+            }
+            */
+            // Hand unroll
+            if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+                t2 = axis_t[0];
+            } else if (axis_t[1] <= axis_t[2]) {
+                t2 = axis_t[1];
+            } else {
+                t2 = axis_t[2];
+            }
+            break;
+        } // while !quit (ray marching loop)
+
+        out_depth[pix_index*p.max_samples+cur_plane] = t;
+        out_depth[p.img_dims[0]*p.img_dims[1]*p.max_samples + pix_index*p.max_samples+cur_plane] = t2;
+        out_voxel_id[pix_index*p.max_samples+cur_plane] = blk_id;
+    } // cur_plane
+}
+
+
+/*
+    out:
+        out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
+        out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
+        out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1],           1, 3]
+    in:
+        in_voxel:     torch CUDA int32  [X, Y, Z] [40, 512, 512]
+        cam_ori:      torch      float  [3]
+        cam_dir:      torch      float  [3]
+        cam_up:       torch      float  [3]
+        cam_f:                   float
+        cam_c:                   int    [2]
+        img_dims:                int    [2]
+        max_samples:             int
+*/
+std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector<float>& cam_c, const std::vector<int>& img_dims, int max_samples) {
+    CHECK_CUDA(in_voxel);
+
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+    torch::Device device = in_voxel.device();
+
+    //assert(in_voxel.dtype() == torch::kU8);
+    assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility
+    assert(in_voxel.dim() == 3);
+    assert(cam_ori.dtype() == torch::kFloat32);
+    assert(cam_ori.numel() == 3);
+    assert(cam_dir.dtype() == torch::kFloat32);
+    assert(cam_dir.numel() == 3);
+    assert(cam_up.dtype() == torch::kFloat32);
+    assert(cam_up.numel() == 3);
+    assert(img_dims.size() == 2);
+
+    RVIP_Params p;
+
+    // Calculate camera rays
+    const torch::Tensor cam_ori_c = cam_ori.cpu();
+    const torch::Tensor cam_dir_c = cam_dir.cpu();
+    const torch::Tensor cam_up_c = cam_up.cpu();
+
+    // Get the coordinate frame of camera space in world space
+    normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
+    cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
+    normalize<float, 3>(p.cam_side);
+    cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
+    normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are normalized. But just in case...
+
+    copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());
+
+    p.cam_f = cam_f;
+    p.cam_c[0] = cam_c[0];
+    p.cam_c[1] = cam_c[1];
+    p.max_samples = max_samples;
+    //printf("[Renderer] max_dist: %ld\n", max_dist);
+
+    p.voxel_dims[0] = in_voxel.size(0);
+    p.voxel_dims[1] = in_voxel.size(1);
+    p.voxel_dims[2] = in_voxel.size(2);
+    p.voxel_strides[0] = in_voxel.stride(0);
+    p.voxel_strides[1] = in_voxel.stride(1);
+    p.voxel_strides[2] = in_voxel.stride(2);
+
+    //printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0], p.voxel_dims[1], p.voxel_dims[2]);
+
+    p.img_dims[0] = img_dims[0];
+    p.img_dims[1] = img_dims[1];
+
+    // Create output tensors
+    // For Minecraft Seg Mask
+    torch::Tensor out_voxel_id = torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kInt32).device(device));
+
+    torch::Tensor out_depth;
+    // Produce two sets of localcoords, one for entry point, the other one for exit point. They share the same corner_ids.
+    out_depth = torch::empty({2, p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+    torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device).requires_grad(false));
+
+    const int TILE_DIM = 8;
+    dim3 dimGrid((p.img_dims[1]+TILE_DIM-1)/TILE_DIM, (p.img_dims[0]+TILE_DIM-1)/TILE_DIM, 1);
+    dim3 dimBlock(TILE_DIM, TILE_DIM, 1);
+
+    ray_voxel_intersection_perspective_kernel<TILE_DIM><<<dimGrid, dimBlock, 0, stream>>>(
+        out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(), out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p
+    );
+
+    return {out_voxel_id, out_depth, out_raydirs};
+}
diff --git a/imaginaire/model_utils/gancraft/voxlib/setup.py b/imaginaire/model_utils/gancraft/voxlib/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eca848211370c8ddf9dda55d5c67804a73061e9
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/setup.py
@@ -0,0 +1,25 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+cxx_args = ['-fopenmp']
+nvcc_args = []
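+# Note: building this extension requires the CUDA toolkit and a PyTorch build with
+# CUDA support; the accompanying Makefile runs `python setup.py build_ext --inplace`
+# followed by `python setup.py install`.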
+
+setup(
+    name='voxrender',
+    ext_modules=[
+        CUDAExtension('voxlib', [
+            'voxlib.cpp',
+            'ray_voxel_intersection.cu',
+            'sp_trilinear_worldcoord_kernel.cu',
+            'positional_encoding_kernel.cu'
+        ],
+            extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}
+        )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bad56fb23f6b8e2a8e41573f8b6b85f9a5693f1
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch.autograd import Function
+import voxlib
+
+"""
+It takes world coordinate as input instead of block-local coordinate. Corner IDs are looked up on-the-fly to
+save memory.
+"""
+
+
+class SparseTrilinearWorldCoordFunction(Function):
+    @staticmethod
+    def forward(ctx, in_feature, corner_lut_t, in_worldcoord, ign_zero):
+
+        out_feature = voxlib.sp_trilinear_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero, -1)
+        ctx.ign_zero = ign_zero
+        ctx.save_for_backward(in_feature, corner_lut_t, in_worldcoord)
+
+        return out_feature
+
+    @staticmethod
+    def backward(ctx, out_feature_grad):
+        in_feature, corner_lut_t, in_worldcoord = ctx.saved_tensors
+
+        assert ctx.needs_input_grad[2] is False
+        in_feature_grad, = voxlib.sp_trilinear_worldcoord_backward(
+            out_feature_grad, in_feature, corner_lut_t, in_worldcoord, ctx.ign_zero, False)
+        return in_feature_grad, None, None, None, None
+
+
+def sparse_trilinear_interp_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero=False):
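+    r"""Trilinearly interpolate sparse per-corner features at world-space coordinates.
+    `corner_lut_t` maps integer corner coordinates to rows of `in_feature`; when
+    ign_zero is True, LUT entry 0 means "no feature" and contributes nothing.
+    """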
+    return SparseTrilinearWorldCoordFunction.apply(in_feature, corner_lut_t, in_worldcoord, ign_zero)
diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..403d01fbda3f528211dd262d019a606f1f8f1640
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu
@@ -0,0 +1,527 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+//
+// Fast routine for sparse tri-linear interpolation of high dimensional features.
+// Ignore label is supported.
+
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <time.h>
+
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+struct SpTrilinear_wc_Params {
+    int in_feature_dim;
+    int in_feature_numentries;
+    int corner_lut_dims[3];
+    int corner_lut_strides[3];
+    int in_worldcoord_dims[8];
+    int in_worldcoord_strides[8];
+    int in_worldcoord_ndim;
+    int out_feature_dims[8];
+    int out_feature_strides[8];
+    bool ign_zero;
+};
+
+
+// out_feature.data_ptr<float>(),
+// in_feature.data_ptr<float>(), corner_lut_t.data_ptr<int32_t>(), in_worldcoord.data_ptr<float>(), p
+template <int TILE_DIM_X, int TILE_DIM_Y, int DUP_X>
+__global__ void sp_trilinear_worldcoord_kernel(
+    float* __restrict__ out_feature,
+    const float* __restrict__ in_feature, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) {
+
+    const int GRID_X = gridDim.y;
+    int idx_entry = blockIdx.x * TILE_DIM_Y + threadIdx.y;
+
+    // Index processing
+    //int index[7];
+    int t = idx_entry;
+    int idx_in_worldcoord = 0;
+    int idx_out_feature = 0;
+    for (int i=p.in_worldcoord_ndim-2; i>=0; i--) {
+        int idx_t = t % p.in_worldcoord_dims[i];
+        t = t / p.in_worldcoord_dims[i];
+        idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t;
+        idx_out_feature += p.out_feature_strides[i] * idx_t;
+    }
+    if (t > 0) {
+        return;
+    }
+    int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1];
+    int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1];
+
+
+    float world_coords[3];
+    world_coords[0] = in_worldcoord[idx_in_worldcoord];
+    world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord];
+    world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2];
+
+    float local_coords[3];
+    int vox_coords[3];
+    local_coords[0] = world_coords[0] - floorf(world_coords[0]);
+    vox_coords[0] = (int)floorf(world_coords[0]);
+    local_coords[1] = world_coords[1] - floorf(world_coords[1]);
+    vox_coords[1] = (int)floorf(world_coords[1]);
+    local_coords[2] = world_coords[2] - floorf(world_coords[2]);
+    vox_coords[2] = (int)floorf(world_coords[2]);
+
+    float interp_weight[8];
+    // 0,0,0
+    interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+    // 0,0,1
+    interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+    // 0,1,0
+    interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+    // 0,1,1
+    interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]);
+    // 1,0,0
+    interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+    // 1,0,1
+    interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+    // 1,1,0
+    interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+    // 1,1,1
+    interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]);
+
+    int indices[8];
+    // Hard boundary check (zero padding)
+    if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {
+        indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1;
+        indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1;
+    } else {
+        // Clamp to boundaries
+        int vox_coords_1[3];
+        vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1);
+        vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1);
+        vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1);
+        vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1);
+        vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1);
+        vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1);
+        int idx_corner_lut;
+        // 000
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[0] = corner_lut_t[idx_corner_lut];
+        // 001
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[1] = corner_lut_t[idx_corner_lut];
+        // 010
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[2] = corner_lut_t[idx_corner_lut];
+        // 011
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[3] = corner_lut_t[idx_corner_lut];
+        // 100
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[4] = corner_lut_t[idx_corner_lut];
+        // 101
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[5] = corner_lut_t[idx_corner_lut];
+        // 110
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[6] = corner_lut_t[idx_corner_lut];
+        // 111
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[7] = corner_lut_t[idx_corner_lut];
+    }
+
+    if (p.ign_zero) {
+        // Zero indices are to be ignored
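+        // The corner LUT is 1-based (0 marks an empty corner); subtracting 1 yields
+        // 0-based feature rows and maps empty corners to -1, which are skipped in
+        // the accumulation loop below.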
+#pragma unroll
+        for (int i=0; i<8; i++) {
+            indices[i] -= 1;
+        }
+    }
+
+    //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x;
+    int idx_feat = blockIdx.y * TILE_DIM_X + threadIdx.x;
+    for (int i=0; i<DUP_X; i++) {
+        if (idx_feat >= p.in_feature_dim) {
+            return;
+        }
+        float interp_feat = 0.0f;
+#pragma unroll
+        for (int j=0; j<8; j++) {
+            if (indices[j] >= 0) {
+                interp_feat = fmaf(in_feature[indices[j]*p.in_feature_dim+idx_feat], interp_weight[j], interp_feat);
+            }
+        }
+        //out_feature[idx_entry*p.in_feature_dim+idx_feat] = interp_feat;
+        out_feature[idx_out_feature+stride_out_feature*idx_feat] = interp_feat;
+        //idx_feat += TILE_DIM_X;
+        idx_feat += TILE_DIM_X * GRID_X;
+    }
+}
+
+
+//sp_trilinear_worldcoord_backward2feature_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_X><<<dimGrid, dimBlock, 0, stream>>>(
+//        in_feature_grad.data_ptr<float>(),
+//        out_feature_grad.data_ptr<float>(), in_feature.data_ptr<float>(), in_corner_lut.data_ptr<int32_t>(), in_worldcoord.data_ptr<float>(), p
+// Backward to feature
+template <int TILE_DIM_X, int TILE_DIM_Y, int DUP_X>
+__global__ void sp_trilinear_worldcoord_backward2feature_kernel(
+    float* __restrict__ in_feature_grad,
+    const float* __restrict__ out_feature_grad, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) {
+
+    const int GRID_X = gridDim.x;
+    int idx_entry = blockIdx.y * TILE_DIM_Y + threadIdx.y;
+
+    // Index processing
+    //int index[7];
+    int t = idx_entry;
+    int idx_in_worldcoord = 0;
+    int idx_out_feature = 0;
+    for (int i=p.in_worldcoord_ndim-2; i>=0; i--) {
+        int idx_t = t % p.in_worldcoord_dims[i];
+        t = t / p.in_worldcoord_dims[i];
+        //index[i] = idx_t;
+        idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t;
+        idx_out_feature += p.out_feature_strides[i] * idx_t;
+    }
+    if (t > 0) {
+        return;
+    }
+    int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1];
+    int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1];
+
+    float world_coords[3];
+    world_coords[0] = in_worldcoord[idx_in_worldcoord];
+    world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord];
+    world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2];
+
+    float local_coords[3];
+    int vox_coords[3];
+    local_coords[0] = world_coords[0] - floorf(world_coords[0]);
+    vox_coords[0] = (int)floorf(world_coords[0]);
+    local_coords[1] = world_coords[1] - floorf(world_coords[1]);
+    vox_coords[1] = (int)floorf(world_coords[1]);
+    local_coords[2] = world_coords[2] - floorf(world_coords[2]);
+    vox_coords[2] = (int)floorf(world_coords[2]);
+
+    float interp_weight[8];
+    // 0,0,0
+    interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+    // 0,0,1
+    interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+    // 0,1,0
+    interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+    // 0,1,1
+    interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]);
+    // 1,0,0
+    interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+    // 1,0,1
+    interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+    // 1,1,0
+    interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+    // 1,1,1
+    interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]);
+
+    int indices[8];
+    // Hard boundary check (zero padding)
+    if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {// ||
+        //vox_coords[0] < 0 || vox_coords[0] >= (p.corner_lut_dims[0]-1) ||
+        //vox_coords[1] < 0 || vox_coords[1] >= (p.corner_lut_dims[1]-1) ||
+        //vox_coords[2] < 0 || vox_coords[2] >= (p.corner_lut_dims[2]-1)) {
+        indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1;
+        indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1;
+    } else {
+        // Clamp to boundaries
+        int vox_coords_1[3];
+        vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1);
+        vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1);
+        vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1);
+        vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1);
+        vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1);
+        vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1);
+        int idx_corner_lut;
+        // 000
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[0] = corner_lut_t[idx_corner_lut];
+        // 001
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[1] = corner_lut_t[idx_corner_lut];
+        // 010
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[2] = corner_lut_t[idx_corner_lut];
+        // 011
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[3] = corner_lut_t[idx_corner_lut];
+        // 100
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[4] = corner_lut_t[idx_corner_lut];
+        // 101
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[5] = corner_lut_t[idx_corner_lut];
+        // 110
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords[2];
+        indices[6] = corner_lut_t[idx_corner_lut];
+        // 111
+        idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+                         p.corner_lut_strides[1] * vox_coords_1[1] +
+                         p.corner_lut_strides[2] * vox_coords_1[2];
+        indices[7] = corner_lut_t[idx_corner_lut];
+    }
+
+    if (p.ign_zero) {
+#pragma unroll
+        for (int i=0; i<8; i++) {
+            indices[i] -= 1;
+        }
+    }
+
+    //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x;
+    int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+    for (int i=0; i<DUP_X; i++) {
+        if (idx_feat >= p.in_feature_dim) {
+            return;
+        }
+        float grad = out_feature_grad[idx_out_feature+stride_out_feature*idx_feat];
+#pragma unroll
+        for (int j=0; j<8; j++) {
+            if (indices[j] >= 0) {
+                //indices[j]*p.in_feature_dim+idx_feat
+                atomicAdd(&in_feature_grad[indices[j]*p.in_feature_dim+idx_feat], grad * interp_weight[j]);
+            }
+        }
+        //idx_feat += TILE_DIM_X;
+        idx_feat += TILE_DIM_X * GRID_X;
+    }
+}
+
+// in_feature, corner_lut_t, in_world_coord, ign_zero=False
+// Input:
+//      in_feature:     float32 [M C]
+//      in_corner_lut:  int32   [X Y Z]
+//      in_worldcoord:  float32 [..., 3]
+//      ---Index:          int32   [..., 8], containing [0, M]. 0 is ignore label.
+//      ---Coord:          float32 [..., 3]
+// Output:
+//      Interp. Feat:   float32 [..., C]
+// std::vector<torch::Tensor>
+torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos) {
+    CHECK_CUDA(in_feature);
+    CHECK_CUDA(in_corner_lut);
+    CHECK_CUDA(in_worldcoord);
+
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+    torch::Device device = in_feature.device();
+
+    // assert(tensor.sizes() == std::vector<int64_t>{3, 4, 5});
+    assert(in_feature.dtype() == torch::kFloat32);
+    assert(in_feature.dim() == 2);
+    assert(in_corner_lut.dtype() == torch::kInt32);
+    assert(in_corner_lut.dim() == 3);
+    assert(in_worldcoord.dtype() == torch::kFloat32);
+    assert(in_worldcoord.size(-1) == 3);
+    assert(in_worldcoord.dim() <= 8);
+
+    CHECK_CONTIGUOUS(in_feature);
+    //CHECK_CONTIGUOUS(in_corner_lut); // Will still run correctly, but performance will suffer.
+    //CHECK_CONTIGUOUS(in_worldcoord);
+
+    //int channel_pos = -1; // -1 for HWC, -3 for CHW
+    if (channel_pos < 0) {
+        channel_pos += in_worldcoord.dim();
+    }
+    assert(channel_pos >= 0 && channel_pos < in_worldcoord.dim());
+
+    SpTrilinear_wc_Params p;
+    p.in_feature_dim = in_feature.size(1);
+    p.in_feature_numentries = in_feature.size(0);
+    p.in_worldcoord_ndim = in_worldcoord.dim();
+    for (int i=0; i<in_worldcoord.dim(); i++) {
+        p.in_worldcoord_dims[i] = in_worldcoord.size(i);
+        p.in_worldcoord_strides[i] = in_worldcoord.stride(i);
+    }
+    p.ign_zero = ign_zero;
+
+    p.corner_lut_dims[0] = in_corner_lut.size(0);
+    p.corner_lut_dims[1] = in_corner_lut.size(1);
+    p.corner_lut_dims[2] = in_corner_lut.size(2);
+    p.corner_lut_strides[0] = in_corner_lut.stride(0);
+    p.corner_lut_strides[1] = in_corner_lut.stride(1);
+    p.corner_lut_strides[2] = in_corner_lut.stride(2);
+
+    int numentries = in_worldcoord.numel() / 3;
+    //printf("FWD numentries: %d\n", numentries);
+
+    std::vector<int64_t> out_feature_shape;
+    //if (channel_first) { // Channel First format, suitable for 2D convolution
+    //    //assert(false);
+    for (int i=0; i<channel_pos; i++) {
+        out_feature_shape.push_back(in_worldcoord.size(i));
+    }
+    out_feature_shape.push_back(p.in_feature_dim);
+    for (int i=channel_pos; i<in_worldcoord.dim()-1; i++) {
+        out_feature_shape.push_back(in_worldcoord.size(i));
+    }
+    torch::Tensor out_feature = torch::empty(out_feature_shape, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+    // The feature is always at the last dimension. Swap it to the last dim.
+    for (int i=channel_pos+1; i<out_feature.dim(); i++) {
+        out_feature.transpose_(i-1, i);
+    }
+    //} else { // Channel Last
+    //    for (int i=0; i<in_worldcoord.dim()-1; i++) {
+    //        out_feature_shape.push_back(in_worldcoord.size(i));
+    //    }
+    //    out_feature_shape.push_back(p.in_feature_dim);
+    //    out_feature = torch::empty(out_feature_shape, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+    //}
+    for (int i=0; i<out_feature.dim(); i++) {
+        p.out_feature_dims[i] = out_feature.size(i);
+        p.out_feature_strides[i] = out_feature.stride(i);
+    }
+
+    const int TILE_DIM_X = 16;  // feature dim
+    const int TILE_DIM_Y = 64;  // entry dim
+    const int DUP_X = 4;   // To amortize the cost of weight computation
+    //dim3 dimGrid((p.in_feature_dim+(TILE_DIM_X*DUP_X)-1)/(TILE_DIM_X*DUP_X), (numentries+TILE_DIM_Y-1)/TILE_DIM_Y, 1);
+    dim3 dimGrid((numentries+TILE_DIM_Y-1)/TILE_DIM_Y, (p.in_feature_dim+(TILE_DIM_X*DUP_X)-1)/(TILE_DIM_X*DUP_X), 1);
+    dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+
+    sp_trilinear_worldcoord_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_X><<<dimGrid, dimBlock, 0, stream>>>(
+        out_feature.data_ptr<float>(),
+        in_feature.data_ptr<float>(), in_corner_lut.data_ptr<int32_t>(), in_worldcoord.data_ptr<float>(), p
+    );
+    THCudaCheck(cudaGetLastError());
+    return out_feature;
+}
+
+
+// Backward function for sparse trilinear interpolation
+// Input:
+//      out_feature_grad:   float32 [..., C]
+//      in_feature:         float32 [M, C]
+//      in_corner_lut:      int32   [X Y Z]
+//      ---in_index:        int32   [..., 8], containing [0, M]. 0 is ignore label.
+//      in_worldcoord:      float32 [..., 3]
+//      ign_zero:           bool
+//      need_coord_grad:    bool
+// Output:
+//      in_feature_grad:    float32 [M, C]
+//      in_coord_grad:      float32 [..., 3]
+std::vector<torch::Tensor> sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad) {
+    assert(need_coord_grad == false);
+    CHECK_CUDA(out_feature_grad);
+    CHECK_CUDA(in_feature);
+    CHECK_CUDA(in_corner_lut);
+    CHECK_CUDA(in_worldcoord);
+
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+    torch::Device device = out_feature_grad.device();
+
+    //for (int i=0; i<out_feature_grad.dim(); i++) {
+    //    printf("[sp_trilinear_backward_cuda] dim, size, stride: %d, %d, %d\n", i, out_feature_grad.size(i), out_feature_grad.stride(i));
+    //}
+    //CHECK_CONTIGUOUS(out_feature_grad);
+    CHECK_CONTIGUOUS(in_feature);
+    //CHECK_CONTIGUOUS(in_worldcoord);
+
+    // assert(tensor.sizes() == std::vector<int64_t>{3, 4, 5});
+    assert(out_feature_grad.dtype() == torch::kFloat32);
+    for (int i=0; i<out_feature_grad.dim()-1; i++) {
+        assert(out_feature_grad.size(i) == in_worldcoord.size(i));
+    }
+    assert(out_feature_grad.size(-1) == in_feature.size(1));
+    assert(in_feature.dtype() == torch::kFloat32);
+    assert(in_feature.dim() == 2);
+    assert(in_worldcoord.dtype() == torch::kFloat32);
+    assert(in_worldcoord.size(-1) == 3);
+
+    SpTrilinear_wc_Params p;
+    p.in_feature_dim = in_feature.size(1);
+    p.in_feature_numentries = in_feature.size(0);
+    p.in_worldcoord_ndim = in_worldcoord.dim();
+    for (int i=0; i<in_worldcoord.dim(); i++) {
+        p.in_worldcoord_dims[i] = in_worldcoord.size(i);
+        p.in_worldcoord_strides[i] = in_worldcoord.stride(i);
+    }
+    p.ign_zero = ign_zero;
+
+    p.corner_lut_dims[0] = in_corner_lut.size(0);
+    p.corner_lut_dims[1] = in_corner_lut.size(1);
+    p.corner_lut_dims[2] = in_corner_lut.size(2);
+    p.corner_lut_strides[0] = in_corner_lut.stride(0);
+    p.corner_lut_strides[1] = in_corner_lut.stride(1);
+    p.corner_lut_strides[2] = in_corner_lut.stride(2);
+
+    for (int i=0; i<out_feature_grad.dim(); i++) {
+        p.out_feature_dims[i] = out_feature_grad.size(i);
+        p.out_feature_strides[i] = out_feature_grad.stride(i);
+    }
+    int numentries = in_worldcoord.numel() / 3;
+
+    // Create output tensors
+    torch::Tensor in_feature_grad = torch::zeros({p.in_feature_numentries, p.in_feature_dim}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+    torch::Tensor in_coord_grad;
+
+    {
+        const int TILE_DIM_X = 16;  // feature dim
+        const int TILE_DIM_Y = 64;  // entry dim
+        const int DUP_X = 4;   // To amortize the cost of weight computation
+        dim3 dimGrid((p.in_feature_dim+(TILE_DIM_X*DUP_X)-1)/(TILE_DIM_X*DUP_X), (numentries+TILE_DIM_Y-1)/TILE_DIM_Y, 1);
+        dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+        //printf("BW dimGrid: %d, %d, %d \n", dimGrid.x, dimGrid.y, dimGrid.z);
+        sp_trilinear_worldcoord_backward2feature_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_X><<<dimGrid, dimBlock, 0, stream>>>(
+            in_feature_grad.data_ptr<float>(),
+            out_feature_grad.data_ptr<float>(), in_corner_lut.data_ptr<int32_t>(), in_worldcoord.data_ptr<float>(), p
+        );
+    }
+
+    THCudaCheck(cudaGetLastError());
+    return {in_feature_grad};
+}
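
For reference, a minimal usage sketch of the sparse trilinear op defined above, assuming the extension is compiled and importable as `voxlib` (the real module name depends on how the extension is built); tensor shapes and dtypes follow the comment block in the forward host function.

```python
# Hypothetical usage sketch of sp_trilinear_worldcoord; `voxlib` is an assumed
# module name for the compiled extension.
import torch
import voxlib

M, C = 100, 16                 # number of sparse corner features / feature channels
X, Y, Z = 8, 8, 8              # corner look-up-table grid size

in_feature = torch.randn(M, C, device='cuda', dtype=torch.float32)
# LUT maps voxel corners to rows of in_feature; 0 is the ignore label when ign_zero=True.
corner_lut = torch.randint(0, M, (X, Y, Z), device='cuda', dtype=torch.int32)
# World coordinates in voxel units, shape [..., 3]; NaN coordinates yield zero features.
world_coord = torch.rand(4, 32, 32, 3, device='cuda') * (X - 1)

# channel_pos=-1 places the feature channel last: output shape [4, 32, 32, C].
out = voxlib.sp_trilinear_worldcoord(in_feature, corner_lut, world_coord,
                                     False,   # ign_zero
                                     -1)      # channel_pos
print(out.shape)
```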
diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..70095052d71f53e5e519a5f57b3c0848998a1b22
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+#include <torch/extension.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+
+// Fast voxel traversal along rays
+std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector<float>& cam_c, const std::vector<int>& img_dims, int max_samples);
+
+
+// World Coordinate Sparse Trilinear Interpolation
+torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos);
+
+std::vector<torch::Tensor> sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad);
+
+// Fast & Memory Efficient Positional Encoding
+torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig);
+
+torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("ray_voxel_intersection_perspective", &ray_voxel_intersection_perspective_cuda, "Ray-voxel intersections given perspective camera parameters (CUDA)");
+    m.def("sp_trilinear_worldcoord", &sp_trilinear_worldcoord_cuda, "Sparse Trilinear interpolation, world coordinate [forward] (CUDA)");
+    m.def("sp_trilinear_worldcoord_backward", &sp_trilinear_worldcoord_backward_cuda, "Sparse Trilinear interpolation, world coordinate [backward] (CUDA)");
+    m.def("positional_encoding", &positional_encoding_cuda, "Fused Positional Encoding [forward] (CUDA)");
+    m.def("positional_encoding_backward", &positional_encoding_backward_cuda, "Fused Positional Encoding [backward] (CUDA)");
+}
\ No newline at end of file
diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..46b47fc80ecf802347607395ff04565732a4ee87
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
@@ -0,0 +1,76 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+#ifndef VOXLIB_COMMON_H
+#define VOXLIB_COMMON_H
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CPU(x) TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+// CUDA vector math functions
+__host__ __device__ __forceinline__ int floor_div(int a, int b) {
+    int c = a / b;
+
+    if (c * b > a) {
+        c--;
+    }
+
+    return c;
+}
+
+template <typename scalar_t>
+__host__ __forceinline__ void cross(scalar_t* r, const scalar_t* a, const scalar_t* b) {
+    r[0] = a[1]*b[2] - a[2]*b[1];
+    r[1] = a[2]*b[0] - a[0]*b[2];
+    r[2] = a[0]*b[1] - a[1]*b[0];
+}
+
+__device__ __host__ __forceinline__ float dot(const float* a, const float* b) {
+    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
+}
+
+template <typename scalar_t, int ndim>
+__device__ __host__ __forceinline__ void copyarr(scalar_t* r, const scalar_t* a) {
+    #pragma unroll
+    for (int i=0; i<ndim; i++) {
+        r[i] = a[i];
+    }
+}
+
+// TODO: use rsqrt to speed up
+// inplace version
+template <typename scalar_t, int ndim>
+__device__ __host__ __forceinline__ void normalize(scalar_t* a) {
+    scalar_t vec_len=0.0f;
+    #pragma unroll
+    for (int i=0; i<ndim; i++) {
+        vec_len += a[i]*a[i];
+    }
+    vec_len = sqrtf(vec_len);
+    #pragma unroll
+    for (int i=0; i<ndim; i++) {
+        a[i] /= vec_len;
+    }
+}
+
+// normalize + copy
+template <typename scalar_t, int ndim>
+__device__ __host__ __forceinline__ void normalize(scalar_t* r, const scalar_t* a) {
+    scalar_t vec_len=0.0f;
+    #pragma unroll
+    for (int i=0; i<ndim; i++) {
+        vec_len += a[i]*a[i];
+    }
+    vec_len = sqrtf(vec_len);
+    #pragma unroll
+    for (int i=0; i<ndim; i++) {
+        r[i] = a[i] / vec_len;
+    }
+}
+
+#endif // VOXLIB_COMMON_H
\ No newline at end of file
diff --git a/imaginaire/model_utils/label.py b/imaginaire/model_utils/label.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd6c76c44e19e1184959a19c1c8ab10ca8c93369
--- /dev/null
+++ b/imaginaire/model_utils/label.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+
+def make_one_hot(cfg, is_inference, data):
+    r"""Convert appropriate image data types to one-hot representation.
+
+    Args:
+        data (dict): Dict containing data_type as key, with each value
+            as a list of torch.Tensors.
+    Returns:
+        data (dict): Same as the input data, but with one-hot representations
+        for the selected types.
+    """
+    assert hasattr(cfg, 'one_hot_num_classes')
+    num_classes = getattr(cfg, 'one_hot_num_classes')
+    use_dont_care = getattr(cfg, 'use_dont_care', False)
+    for data_type, data_type_num_classes in num_classes.items():
+        if data_type in data.keys():
+            data[data_type] = _encode_onehot(data[data_type] * 255.0, data_type_num_classes, use_dont_care).float()
+    return data
+
+
+def concat_labels(cfg, is_inference, data):
+    assert hasattr(cfg, 'input_labels')
+    input_labels = getattr(cfg, 'input_labels')
+    dataset_type = getattr(cfg, 'type')
+
+    # Package output.
+    labels = []
+    for data_type in input_labels:
+        label = data.pop(data_type)
+        labels.append(label)
+    if not ('video' in dataset_type):
+        data['label'] = torch.cat(labels, dim=0)
+    else:
+        data['label'] = torch.cat(labels, dim=1)
+    return data
+
+
+def concat_few_shot_labels(cfg, is_inference, data):
+    assert hasattr(cfg, 'input_few_shot_labels')
+    input_labels = getattr(cfg, 'input_few_shot_labels')
+    dataset_type = getattr(cfg, 'type')
+
+    # Package output.
+    labels = []
+    for data_type in input_labels:
+        label = data.pop(data_type)
+        labels.append(label)
+    if not ('video' in dataset_type):
+        data['few_shot_label'] = torch.cat(labels, dim=0)
+    else:
+        data['few_shot_label'] = torch.cat(labels, dim=1)
+    return data
+
+
+def move_dont_care(cfg, is_inference, data):
+    assert hasattr(cfg, 'move_dont_care')
+    move_dont_care = getattr(cfg, 'move_dont_care')
+    for data_type, data_type_num_classes in move_dont_care.items():
+        label_map = data[data_type] * 255.0
+        label_map[label_map < 0] = data_type_num_classes
+        label_map[label_map >= data_type_num_classes] = data_type_num_classes
+        data[data_type] = label_map / 255.0
+    return data
+
+
+def _encode_onehot(label_map, num_classes, use_dont_care):
+    r"""Make input one-hot.
+
+    Args:
+        label_map (torch.Tensor): (C, H, W) tensor containing indices.
+        num_classes (int): Number of labels to expand tensor to.
+        use_dont_care (bool): Use the don't-care label or not?
+    Returns:
+        output (torch.Tensor): (num_classes, H, W) one-hot tensor.
+    """
+    # All labels lie in [0, num_classes - 1].
+    # Encode don't-care as num_classes.
+    label_map[label_map < 0] = num_classes
+    label_map[label_map >= num_classes] = num_classes
+
+    size = label_map.size()
+    output_size = (num_classes + 1, size[-2], size[-1])
+    output = torch.zeros(*output_size)
+    if label_map.dim() == 4:
+        output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1)
+        output = output.scatter_(1, label_map.data.long(), 1.0)
+        if not use_dont_care:
+            output = output[:, :num_classes, ...]
+    else:
+        output = output.scatter_(0, label_map.data.long(), 1.0)
+        if not use_dont_care:
+            output = output[:num_classes, ...]
+    return output
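
A minimal sketch of the one-hot conversion above on a toy label map with integer class ids stored as floats (`make_one_hot` rescales normalized maps by 255 before calling `_encode_onehot`):

```python
# Hypothetical sketch; assumes the imaginaire package is importable.
import torch
from imaginaire.model_utils.label import _encode_onehot

num_classes = 5
label_ids = torch.randint(0, num_classes, (1, 4, 4))          # (C=1, H, W)
one_hot = _encode_onehot(label_ids.float(), num_classes, use_dont_care=False)

print(one_hot.shape)                                     # torch.Size([5, 4, 4])
print(torch.equal(one_hot.argmax(dim=0), label_ids[0]))  # True
```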
diff --git a/imaginaire/model_utils/pix2pixHD.py b/imaginaire/model_utils/pix2pixHD.py
new file mode 100644
index 0000000000000000000000000000000000000000..862eb591a665e3ea81b3f918696feac3640a6c94
--- /dev/null
+++ b/imaginaire/model_utils/pix2pixHD.py
@@ -0,0 +1,227 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Utils for the pix2pixHD model."""
+import numpy as np
+import torch
+
+from imaginaire.utils.data import get_paired_input_label_channel_number
+from imaginaire.utils.distributed import dist_all_gather_tensor, is_master
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.trainer import (get_optimizer, get_optimizer_for_params,
+                                      wrap_model_and_optimizer)
+from sklearn.cluster import KMeans
+
+
+def cluster_features(cfg, train_data_loader, net_E,
+                     preprocess=None, small_ratio=0.0625, is_cityscapes=True):
+    r"""Use clustering to compute the features.
+
+    Args:
+        cfg (obj): Global configuration file.
+        train_data_loader (obj): Dataloader for iterate through the training
+            set.
+        net_E (nn.Module): Pytorch network.
+        preprocess (function): Pre-processing function.
+        small_ratio (float): We only consider instances that occupy at least
+            $(small_ratio) of the image area.
+        is_cityscapes (bool): Is this the Cityscapes dataset? In the
+            Cityscapes dataset, the instance labels for cars start at 26001,
+            26002, ...
+
+    Returns:
+        ( num_labels x num_cluster_centers x feature_dims): cluster centers.
+    """
+    # Encode features.
+    label_nc = get_paired_input_label_channel_number(cfg.data)
+    feat_nc = cfg.gen.enc.num_feat_channels
+    n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10)
+    # Compute features.
+    features = {}
+    for label in range(label_nc):
+        features[label] = np.zeros((0, feat_nc + 1))
+    for data in train_data_loader:
+        if preprocess is not None:
+            data = preprocess(data)
+        feat = encode_features(net_E, feat_nc, label_nc,
+                               data['images'], data['instance_maps'],
+                               is_cityscapes)
+        # We only collect the feature vectors for the master GPU.
+        if is_master():
+            for label in range(label_nc):
+                features[label] = np.append(
+                    features[label], feat[label], axis=0)
+    # Clustering.
+    # We only perform clustering for the master GPU.
+    if is_master():
+        for label in range(label_nc):
+            feat = features[label]
+            # We only consider segments that are greater than a pre-set
+            # threshold.
+            feat = feat[feat[:, -1] > small_ratio, :-1]
+            if feat.shape[0]:
+                n_clusters = min(feat.shape[0], n_clusters)
+                kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
+                n, d = kmeans.cluster_centers_.shape
+                this_cluster = getattr(net_E, 'cluster_%d' % label)
+                this_cluster[0:n, :] = torch.Tensor(
+                    kmeans.cluster_centers_).float()
+
+
+def encode_features(net_E, feat_nc, label_nc, image, inst,
+                    is_cityscapes=True):
+    r"""Compute feature embeddings for an image image.
+    TODO(Ting-Chun): To make this funciton dataset independent.
+
+    Args:
+        net_E (nn.Module): The encoder network.
+        feat_nc (int): Feature dimensions
+        label_nc (int): Number of segmentation labels.
+        image (tensor): Input image tensor.
+        inst (tensor): Input instance map.
+        is_cityscapes (bool): Is this the Cityscapes dataset? In the
+            Cityscapes dataset, the instance labels for cars start at 26001,
+            26002, ...
+    Returns:
+        (list of list of numpy vectors): We will have $(label_nc)
+            lists. Each list records a set of feature vectors of
+            dimension $(feat_nc+1), where the first $(feat_nc) dimensions are
+            the representative feature of an instance and the last dimension
+            is its area proportion.
+    """
+    # h, w = inst.size()[2:]
+    feat_map = net_E(image, inst)
+    feature_map_gather = dist_all_gather_tensor(feat_map)
+    inst_gathered = dist_all_gather_tensor(inst)
+    # Initialize the cluster centers.
+    # For each feature vector,
+    #   0:feat_nc will be the feature vector.
+    #   The feat_nc dimension record the percentage of the instance.
+    feature = {}
+    for i in range(label_nc):
+        feature[i] = np.zeros((0, feat_nc + 1))
+    if is_master():
+        all_feat_map = torch.cat(feature_map_gather, 0)
+        all_inst_map = torch.cat(inst_gathered, 0)
+        # Scan through the batches.
+        for n in range(all_feat_map.size()[0]):
+            feat_map = all_feat_map[n:(n + 1), :, :, :]
+            inst = all_inst_map[n:(n + 1), :, :, :]
+            fh, fw = feat_map.size()[2:]
+            inst_np = inst.cpu().numpy().astype(int)
+            for i in np.unique(inst_np):
+                if is_cityscapes:
+                    label = i if i < 1000 else i // 1000
+                else:
+                    label = i
+                idx = (inst == int(i)).nonzero()
+                num = idx.size()[0]
+                # We will just pick the middle pixel as its representative
+                # feature.
+                idx = idx[num // 2, :]
+                val = np.zeros((1, feat_nc + 1))
+                for k in range(feat_nc):
+                    # We expect idx[0]=0 and idx[1]=0, since we process one
+                    # sample at a time (idx[0]=0) and the instance map has a
+                    # single channel (idx[1]=0).
+                    val[0, k] = feat_map[
+                        idx[0], idx[1] + k, idx[2], idx[3]].item()
+                val[0, feat_nc] = float(num) / (fh * fw)
+                feature[label] = np.append(feature[label], val, axis=0)
+        return feature
+    else:
+        return feature
+
+
+def get_edges(t):
+    r""" Compute edge maps for a given input instance map.
+
+    Args:
+        t (4D tensor): Input instance map.
+    Returns:
+        (4D tensor): Output edge map.
+    """
+    edge = torch.cuda.ByteTensor(t.size()).zero_()
+    edge[:, :, :, 1:] = edge[:, :, :, 1:] | (
+        t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
+    edge[:, :, :, :-1] = edge[:, :, :, :-1] | (
+        t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
+    edge[:, :, 1:, :] = edge[:, :, 1:, :] | (
+        t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
+    edge[:, :, :-1, :] = edge[:, :, :-1, :] | (
+        t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
+    return edge.float()
+
+
+def get_train_params(net, param_names_start_with=[], param_names_include=[]):
+    r"""Get train parameters.
+
+    Args:
+        net (obj): Network object.
+        param_names_start_with (list of strings): Params whose names
+            start with any of the strings will be trained.
+        param_names_include (list of strings): Params whose names include
+            any of the strings will be trained.
+    """
+    params_to_train = []
+    params_dict = net.state_dict()
+    list_of_param_names_to_train = set()
+    # Iterate through all params in the network and check if we need to
+    # train it.
+    for key, value in params_dict.items():
+        do_train = False
+        # If the param name starts with the target string (excluding the
+        # 'module.' / 'averaged_model.' prefixes), we will train this param.
+        key_s = key.replace('module.', '').replace('averaged_model.', '')
+        for param_name in param_names_start_with:
+            if key_s.startswith(param_name):
+                do_train = True
+                list_of_param_names_to_train.add(param_name)
+
+        # Otherwise, if the param name includes the target string,
+        # we will also train it.
+        if not do_train:
+            for param_name in param_names_include:
+                if param_name in key_s:
+                    do_train = True
+                    full_param_name = \
+                        key_s[:(key_s.find(param_name) + len(param_name))]
+                    list_of_param_names_to_train.add(full_param_name)
+
+        # If we decide to train the param, add it to the list to train.
+        if do_train:
+            module = net
+            key_list = key.split('.')
+            for k in key_list:
+                module = getattr(module, k)
+            params_to_train += [module]
+
+    print('Training layers: ', sorted(list_of_param_names_to_train))
+    return params_to_train
+
+
+def get_optimizer_with_params(cfg, net_G, net_D, param_names_start_with=[],
+                              param_names_include=[]):
+    r"""Return the optimizer object.
+
+    Args:
+        cfg (obj): Global config.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        param_names_start_with (list of strings): Params whose names
+            start with any of the strings will be trained.
+        param_names_include (list of strings): Params whose names include
+            any of the strings will be trained.
+    """
+    # If any of the param name lists is non-empty, only those params are
+    # trained. Otherwise the entire network (all params) is trained.
+    if param_names_start_with or param_names_include:
+        params = get_train_params(net_G, param_names_start_with,
+                                  param_names_include)
+    else:
+        params = net_G.parameters()
+
+    opt_G = get_optimizer_for_params(cfg.gen_opt, params)
+    opt_D = get_optimizer(cfg.dis_opt, net_D)
+    return wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)
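
A minimal sketch of `get_edges` on a toy instance map; it requires a CUDA device, since the function allocates a `torch.cuda.ByteTensor` internally, and assumes the imaginaire package and its dependencies are importable.

```python
import torch
from imaginaire.model_utils.pix2pixHD import get_edges

# Two instances side by side: ids 1 and 2 in a 1x1x4x4 instance map.
inst = torch.tensor([[1, 1, 2, 2]], dtype=torch.float32).repeat(4, 1)
inst = inst.view(1, 1, 4, 4).cuda()

edges = get_edges(inst)   # 1.0 on the columns adjacent to the id boundary
print(edges[0, 0])
```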
diff --git a/imaginaire/model_utils/rename_inputs.py b/imaginaire/model_utils/rename_inputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40b3f98f6bf21f9efb21c9b7cd99226adbb2769
--- /dev/null
+++ b/imaginaire/model_utils/rename_inputs.py
@@ -0,0 +1,15 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+
+def rename_inputs(cfg, is_inference, data):
+    assert hasattr(cfg, 'rename_inputs')
+    attr = getattr(cfg, 'rename_inputs')
+    for key in attr.keys():
+        value = attr[key]
+        data[key] = data[value]
+        # Delete the old key.
+        del data[value]
+    return data
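
A minimal, hypothetical sketch of `rename_inputs`: the `cfg.rename_inputs` mapping goes from the new key to the existing key it should be copied from (the key names below are illustrative only).

```python
from types import SimpleNamespace
from imaginaire.model_utils.rename_inputs import rename_inputs

cfg = SimpleNamespace(rename_inputs={'label': 'seg_maps'})
data = {'seg_maps': 'dummy_tensor'}

data = rename_inputs(cfg, is_inference=False, data=data)
print(data)   # {'label': 'dummy_tensor'}
```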
diff --git a/imaginaire/model_utils/wc_vid2vid/render.py b/imaginaire/model_utils/wc_vid2vid/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..304b1cb3b58384ad8fb99bacf43ddda0bc0b2ff6
--- /dev/null
+++ b/imaginaire/model_utils/wc_vid2vid/render.py
@@ -0,0 +1,199 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import pickle
+import time
+
+import numpy as np
+
+
+class SplatRenderer(object):
+    """Splatting 3D point cloud into image using precomputed mapping."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """Reset the renderer."""
+        # 1 = point seen before, 0 = not seen.
+        # This is numpy uint8 array of size (N, 1)
+        self.seen_mask = None
+
+        # Time of first colorization of 3D point.
+        # This is numpy uint16 array of size (N, 1)
+        self.seen_time = None
+
+        # colors[kp_idx] is color of kp_idx'th keypoint.
+        # This is a numpy uint8 array of size (N, 3)
+        self.colors = None
+
+        self.time_taken = 0
+        self.call_idx = 0
+
+    def num_points(self):
+        r"""Number of points with assigned colors."""
+        return np.sum(self.seen_mask)
+
+    def _resize_arrays(self, max_point_idx):
+        r"""Makes arrays bigger, if needed.
+        Args:
+            max_point_idx (int): Highest 3D point index seen so far.
+        """
+        if self.colors is None:
+            old_max_point_idx = 0
+        else:
+            old_max_point_idx = self.colors.shape[0]
+
+        if max_point_idx > old_max_point_idx:
+            # Init new bigger arrays.
+            colors = np.zeros((max_point_idx, 3), dtype=np.uint8)
+            seen_mask = np.zeros((max_point_idx, 1), dtype=np.uint8)
+            seen_time = np.zeros((max_point_idx, 1), dtype=np.uint16)
+            # Copy old colors, if exist.
+            if old_max_point_idx > 0:
+                colors[:old_max_point_idx] = self.colors
+                seen_mask[:old_max_point_idx] = self.seen_mask
+                seen_time[:old_max_point_idx] = self.seen_time
+            # Reset pointers.
+            self.colors = colors
+            self.seen_mask = seen_mask
+            self.seen_time = seen_time
+
+    def update_point_cloud(self, image, point_info):
+        r"""Updates point cloud with new points and colors.
+        Args:
+            image (H x W x 3, uint8): Select colors from this image to assign to
+            3D points which do not have previously assigned colors.
+            point_info (N x 3): (i, j, 3D point idx) per row containing
+            mapping of image pixel to 3D point in point cloud.
+        """
+        if point_info is None or len(point_info) == 0:
+            return
+
+        start = time.time()
+        self.call_idx += 1
+
+        i_idxs = point_info[:, 0]
+        j_idxs = point_info[:, 1]
+        point_idxs = point_info[:, 2]
+
+        # Allocate memory for new colors.
+        max_point_idx = np.max(np.array(point_idxs)) + 1
+        self._resize_arrays(max_point_idx)
+        # print('max point idx:', max_point_idx)
+
+        # Save only the new colors.
+        self.colors[point_idxs] = \
+            self.seen_mask[point_idxs] * self.colors[point_idxs] + \
+            (1 - self.seen_mask[point_idxs]) * image[i_idxs, j_idxs]
+
+        # Save point seen times.
+        self.seen_time[point_idxs] = \
+            self.seen_mask[point_idxs] * self.seen_time[point_idxs] + \
+            (1 - self.seen_mask[point_idxs]) * self.call_idx
+
+        # Update seen point mask.
+        self.seen_mask[point_idxs] = 1
+
+        end = time.time()
+        self.time_taken += (end - start)
+
+    def render_image(self, point_info, w, h, return_mask=False):
+        r"""Creates image of (h, w) and fills in colors.
+        Args:
+            point_info (N x 3): (i, j, 3D point idx) per row containing
+            mapping of image pixel to 3D point in point cloud.
+            w (int): Width of output image.
+            h (int): Height of output image.
+            return_mask (bool): Return binary mask of coloring.
+        Returns:
+            (tuple):
+              - output (H x W x 3, uint8): Image formed with mapping and colors.
+              - mask (H x W x 1, uint8): Binary (255 or 0) mask of colorization.
+        """
+        output = np.zeros((h, w, 3), dtype=np.uint8)
+        mask = np.zeros((h, w, 1), dtype=np.uint8)
+
+        if point_info is None or len(point_info) == 0:
+            if return_mask:
+                return output, mask
+            else:
+                return output
+
+        start = time.time()
+
+        i_idxs = point_info[:, 0]
+        j_idxs = point_info[:, 1]
+        point_idxs = point_info[:, 2]
+
+        # Allocate memory for new colors.
+        max_point_idx = np.max(np.array(point_idxs)) + 1
+        self._resize_arrays(max_point_idx)
+
+        # num_found = np.sum(self.seen_mask[point_idxs])
+        # print('Found %d points to color' % (num_found))
+
+        # Copy colors.
+        output[i_idxs, j_idxs] = self.colors[point_idxs]
+
+        end = time.time()
+        self.time_taken += (end - start)
+
+        if return_mask:
+            mask[i_idxs, j_idxs] = 255 * self.seen_mask[point_idxs]
+            return output, mask
+        else:
+            return output
+
+
+def decode_unprojections(data):
+    r"""Unpickle unprojections and make array.
+    Args:
+        data (array of pickled info): Each pickled string has keypoint mapping
+        info.
+    Returns:
+        output (dict): Keys are the different resolutions, and values are padded
+        mapping information.
+    """
+
+    # Unpickle unprojections and store them in a dict with resolutions as keys.
+    all_unprojections = {}
+    for item in data:
+        info = pickle.loads(item)
+
+        for resolution, value in info.items():
+            if resolution not in all_unprojections:
+                all_unprojections[resolution] = []
+
+            if not value or value is None:
+                point_info = []
+            else:
+                point_info = value
+            all_unprojections[resolution].append(point_info)
+
+    outputs = {}
+    for resolution, values in all_unprojections.items():
+        # Get max length of mapping.
+        max_len = 0
+        for value in values:
+            max_len = max(max_len, len(value))
+            # Entries are a 3-tuple of (i_idx, j_idx, point_idx).
+            assert len(value) % 3 == 0
+
+        # Pad each mapping to max_len.
+        values = [
+            value +  # Original info.
+            [-1] * (max_len - len(value)) +  # Padding.
+            [len(value) // 3] * 3  # End sentinel with length.
+            for value in values
+        ]
+
+        # Convert each mapping to numpy and reshape.
+        values = [np.array(value).reshape(-1, 3) for value in values]
+
+        # Stack and put in output.
+        # Shape is (T, N, 3). T is time steps, N is num mappings.
+        outputs[resolution] = np.stack(values, axis=0)
+
+    return outputs
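
A minimal sketch of the `SplatRenderer` flow on synthetic data: each `point_info` row is `(row i, column j, 3D point index)`, as described in the docstrings above.

```python
import numpy as np
from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer

h, w = 4, 6
image = np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
# Three pixels of the image are mapped to 3D points 0, 1 and 7.
point_info = np.array([[0, 0, 0],
                       [1, 2, 1],
                       [3, 5, 7]])

renderer = SplatRenderer()
renderer.update_point_cloud(image, point_info)           # remember colors
out, mask = renderer.render_image(point_info, w, h, return_mask=True)

print(renderer.num_points())                             # 3 points have colors
print((out[0, 0] == image[0, 0]).all(), mask[0, 0, 0])   # True 255
```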
diff --git a/imaginaire/optimizers/__init__.py b/imaginaire/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bedc71589e178b96730174c9860b8cc8430e55
--- /dev/null
+++ b/imaginaire/optimizers/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .fromage import Fromage
+from .madam import Madam
+
+__all__ = ['Fromage', 'Madam']
diff --git a/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4a0a8977871d0626a62bfdd2561625a4fc3b5d7
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f8740580fc34277745504381d463ec83bd9d493
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a384b1798545cd7b6d9b3515e1afc2b7fd892f68
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/fromage.py b/imaginaire/optimizers/fromage.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00203de89f55fd122f71b7de8718ed7ef681ec8
--- /dev/null
+++ b/imaginaire/optimizers/fromage.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# import torch
+import math
+
+from torch.optim.optimizer import Optimizer, required
+
+
+class Fromage(Optimizer):
+    r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432)"""
+
+    def __init__(self, params, lr=required, momentum=0):
+        if lr is not required and lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        defaults = dict(lr=lr, momentum=momentum)
+        super(Fromage, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        r"""Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                d_p = p.grad.data
+                d_p_norm = p.grad.norm()
+                p_norm = p.norm()
+                if p_norm > 0.0 and d_p_norm > 0.0:
+                    p.data.add_(-group['lr'], d_p * (p_norm / d_p_norm))
+                else:
+                    p.data.add_(-group['lr'], d_p)
+                p.data /= math.sqrt(1 + group['lr'] ** 2)
+
+        return loss
diff --git a/imaginaire/optimizers/madam.py b/imaginaire/optimizers/madam.py
new file mode 100644
index 0000000000000000000000000000000000000000..11bf71d049d9323e9ba646713413578ae5eb4503
--- /dev/null
+++ b/imaginaire/optimizers/madam.py
@@ -0,0 +1,54 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.optim.optimizer import Optimizer, required
+
+
+class Madam(Optimizer):
+    r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560)"""
+    def __init__(self, params, lr=required, scale=3.0,
+                 g_bound=None, momentum=0):
+        self.scale = scale
+        self.g_bound = g_bound
+        defaults = dict(lr=lr, momentum=momentum)
+        super(Madam, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        r"""Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['max'] = self.scale * (p * p).mean().sqrt().item()
+                    state['step'] = 0
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                state['step'] += 1
+                bias_correction = 1 - 0.999 ** state['step']
+                state['exp_avg_sq'] = 0.999 * state[
+                    'exp_avg_sq'] + 0.001 * p.grad.data ** 2
+                g_normed = \
+                    p.grad.data / (state['exp_avg_sq'] / bias_correction).sqrt()
+                g_normed[torch.isnan(g_normed)] = 0
+                if self.g_bound is not None:
+                    g_normed.clamp_(-self.g_bound, self.g_bound)
+
+                p.data *= torch.exp(
+                    -group['lr'] * g_normed * torch.sign(p.data))
+                p.data.clamp_(-state['max'], state['max'])
+
+        return loss
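
A minimal sketch of plugging the optimizers exported above into a standard training step. Madam is shown; Fromage is a drop-in replacement, but its `step()` uses the older `Tensor.add_(scalar, tensor)` overload, so it may require an older PyTorch release.

```python
import torch
from imaginaire.optimizers import Fromage, Madam

model = torch.nn.Linear(8, 2)
opt = Madam(model.parameters(), lr=0.01)   # or Fromage(model.parameters(), lr=0.01)

x, y = torch.randn(16, 8), torch.randn(16, 2)
loss = torch.nn.functional.mse_loss(model(x), y)

opt.zero_grad()
loss.backward()
opt.step()
```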
diff --git a/imaginaire/third_party/__init__.py b/imaginaire/third_party/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5c112b422e5e97bfa8f09f096fbac9b1e6bae54
Binary files /dev/null and b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/__init__.py b/imaginaire/third_party/bias_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dfe0aec6e5fd4a1538ed959abff7c5106784c9b
--- /dev/null
+++ b/imaginaire/third_party/bias_act/__init__.py
@@ -0,0 +1,3 @@
+from .bias_act import FusedNonlinearity
+
+__all__ = ['FusedNonlinearity']
diff --git a/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11e3a34fdbfff20d3f67e073c1b6048c72f57216
Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d99f8e6245f42b7450319d0f9bd9974ee53b5402
Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/bias_act.py b/imaginaire/third_party/bias_act/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b01dc97884036aec1c42feb184c510c5ad0870
--- /dev/null
+++ b/imaginaire/third_party/bias_act/bias_act.py
@@ -0,0 +1,219 @@
+# flake8: noqa
+import numpy as np
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+import bias_act_cuda
+
+# ----------------------------------------------------------------------------
+
+activation_funcs = {
+    'linear': SimpleNamespace(func=lambda x, **_: x, def_alpha=0, def_gain=1,
+                              cuda_idx=1, ref='', has_2nd_grad=False),
+    'relu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.relu(x),
+                            def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
+                            ref='y', has_2nd_grad=False),
+    'leakyrelu': SimpleNamespace(
+        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+        def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y',
+        has_2nd_grad=False),
+    'tanh': SimpleNamespace(func=lambda x, **_: torch.tanh(x), def_alpha=0,
+                            def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+    'sigmoid': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x),
+                               def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
+                               has_2nd_grad=True),
+    'elu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.elu(x),
+                           def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
+                           has_2nd_grad=True),
+    'selu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.selu(x),
+                            def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
+                            has_2nd_grad=True),
+    'softplus': SimpleNamespace(
+        func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0,
+        def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+    'swish': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x) * x,
+                             def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9,
+                             ref='x', has_2nd_grad=True),
+}
+
+# ----------------------------------------------------------------------------
+
+_null_tensor = torch.empty([0])
+
+
+def _bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
+              impl='cuda'):
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda':
+        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain,
+                              clamp=clamp).apply(x, b)
+    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain,
+                         clamp=clamp)
+
+
+# ----------------------------------------------------------------------------
+
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    assert isinstance(x, torch.Tensor)
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Add bias.
+    if b is not None:
+        assert isinstance(b, torch.Tensor) and b.ndim == 1
+        assert 0 <= dim < x.ndim
+        assert b.shape[0] == x.shape[dim]
+        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+    # Evaluate activation function.
+    alpha = float(alpha)
+    x = spec.func(x, alpha=alpha)
+
+    # Scale by gain.
+    gain = float(gain)
+    if gain != 1:
+        x = x * gain
+
+    # Clamp.
+    if clamp >= 0:
+        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
+    return x
+
+
+# ----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Fast CUDA implementation of `bias_act()` using custom ops.
+    """
+    # Parse arguments.
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Lookup from cache.
+    key = (dim, act, alpha, gain, clamp)
+    if key in _bias_act_cuda_cache:
+        return _bias_act_cuda_cache[key]
+
+    # Forward op.
+    class BiasActCuda(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, b):  # pylint: disable=arguments-differ
+            if x.ndim > 2 and x.stride()[1] == 1:
+                ctx.memory_format = torch.channels_last
+            else:
+                ctx.memory_format = torch.contiguous_format
+            x = x.contiguous(memory_format=ctx.memory_format)
+            b = b.contiguous() if b is not None else _null_tensor
+            y = x
+            if act != 'linear' or gain != 1 or clamp >= 0 or b is not \
+                    _null_tensor:
+                y = bias_act_cuda.bias_act_cuda(x, b, _null_tensor, _null_tensor,
+                                                _null_tensor, 0, dim, spec.cuda_idx, alpha,
+                                                gain, clamp)
+            ctx.save_for_backward(
+                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+                y if 'y' in spec.ref else _null_tensor)
+            return y
+
+        @staticmethod
+        def backward(ctx, dy):  # pylint: disable=arguments-differ
+            dy = dy.contiguous(memory_format=ctx.memory_format)
+            x, b, y = ctx.saved_tensors
+            dx = None
+            db = None
+
+            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+                dx = dy
+                if act != 'linear' or gain != 1 or clamp >= 0:
+                    dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+            if ctx.needs_input_grad[1]:
+                db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+            return dx, db
+
+    # Backward op.
+    class BiasActCudaGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, dy, x, b, y):  # pylint: disable=arguments-differ
+            if x.ndim > 2 and x.stride()[1] == 1:
+                ctx.memory_format = torch.channels_last
+            else:
+                ctx.memory_format = torch.contiguous_format
+            dx = bias_act_cuda.bias_act_cuda(dy, b, x, y, _null_tensor, 1, dim,
+                                             spec.cuda_idx, alpha, gain, clamp)
+            ctx.save_for_backward(
+                dy if spec.has_2nd_grad else _null_tensor,
+                x, b, y)
+            return dx
+
+        @staticmethod
+        def backward(ctx, d_dx):  # pylint: disable=arguments-differ
+            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+            dy, x, b, y = ctx.saved_tensors
+            d_dy = None
+            d_x = None
+            d_b = None
+            d_y = None
+
+            if ctx.needs_input_grad[0]:
+                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+            if spec.has_2nd_grad and (
+                    ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+                d_x = bias_act_cuda.bias_act_cuda(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx,
+                                                  alpha, gain, clamp)
+
+            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+            return d_dy, d_x, d_b, d_y
+
+    # Add to cache.
+    _bias_act_cuda_cache[key] = BiasActCuda
+    return BiasActCuda
+
+
+class FusedNonlinearity(nn.Module):
+    def __init__(self, nonlinearity, num_channels=None, lr_mul=1.0, alpha=None, impl='cuda', gain=None):
+        super().__init__()
+        if num_channels is not None:
+            self.bias = nn.Parameter(torch.zeros(num_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.nonlinearity = nonlinearity
+        self.gain = gain
+        self.alpha = alpha
+        self.lr_mul = lr_mul
+        self.impl = impl
+
+    def forward(self, x):
+        bias = self.bias.type_as(x) * self.lr_mul if self.bias is not None else None
+        return _bias_act(
+            x, b=bias, dim=1, act=self.nonlinearity,
+            alpha=self.alpha, gain=self.gain, clamp=None, impl=self.impl
+        )
+
+    def __repr__(self):
+        mod_str = f'{self.__class__.__name__}(type={self.nonlinearity}'
+        if self.gain is not None:
+            mod_str += f', gain={self.gain}'
+        if self.alpha is not None:
+            mod_str += f', alpha={self.alpha}'
+        if self.lr_mul != 1:
+            mod_str += f', lr_mul={self.lr_mul}'
+        mod_str += ')'
+        return mod_str
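
For quick orientation, here is a minimal usage sketch of `FusedNonlinearity`. It assumes the `bias_act_cuda` extension built by the `setup.py` below is importable, and that `activation_funcs`, defined earlier in this file, registers `'lrelu'` as in the upstream StyleGAN2 op; the channel count and input sizes are illustrative only.

```python
import torch

# Fused bias + leaky-ReLU over the channel dimension (dim=1).
act = FusedNonlinearity(nonlinearity='lrelu', num_channels=64)

if torch.cuda.is_available():
    act = act.cuda()
    x = torch.randn(4, 64, 32, 32, device='cuda')
    y = act(x)         # bias add + activation + gain in a single fused kernel
    print(y.shape)     # torch.Size([4, 64, 32, 32])
```
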
diff --git a/imaginaire/third_party/bias_act/setup.py b/imaginaire/third_party/bias_act/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..13607b6083bd59eba24c5f4ac48a34048b55f642
--- /dev/null
+++ b/imaginaire/third_party/bias_act/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    # Compare the major version numerically; a plain string comparison would
+    # mis-order versions such as '9.2' and '11.0'.
+    if int(cuda_version.split('.')[0]) >= 11:
+        nvcc_args.append('-gencode')
+        nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+    name='bias_act_cuda',
+    py_modules=['bias_act'],
+    ext_modules=[
+        CUDAExtension('bias_act_cuda', [
+            './src/bias_act_cuda.cc',
+            './src/bias_act_cuda_kernel.cu'
+        ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+                               'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.cc b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cf975dbe6784e89cfa056574da8780d1e5f5b97d
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/torch.h>
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "bias_act_cuda.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act_cuda(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x     = x.data_ptr();
+    p.b     = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y     = y.data_ptr();
+    p.grad  = grad;
+    p.act   = act;
+    p.alpha = alpha;
+    p.gain  = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+    p.loopX = 4;
+    int blockSize = 4 * 32;
+    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("bias_act_cuda", &bias_act_cuda);
+}
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.h b/imaginaire/third_party/bias_act/src/bias_act_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+    const void* x;      // [sizeX]
+    const void* b;      // [sizeB] or NULL
+    const void* xref;   // [sizeX] or NULL
+    const void* yref;   // [sizeX] or NULL
+    const void* dy;     // [sizeX] or NULL
+    void*       y;      // [sizeX]
+
+    int         grad;
+    int         act;
+    float       alpha;
+    float       gain;
+    float       clamp;
+
+    int         sizeX;
+    int         sizeB;
+    int         stepB;
+    int         loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9adbb942b5ce5740a5527449995e1887cda12816
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act_cuda.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>     { typedef double scalar_t; };
+template <> struct InternalType<float>      { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    int G                 = p.grad;
+    scalar_t alpha        = (scalar_t)p.alpha;
+    scalar_t gain         = (scalar_t)p.gain;
+    scalar_t clamp        = (scalar_t)p.clamp;
+    scalar_t one          = (scalar_t)1;
+    scalar_t two          = (scalar_t)2;
+    scalar_t expRange     = (scalar_t)80;
+    scalar_t halfExpRange = (scalar_t)40;
+    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
+    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;
+
+    // Loop over elements.
+    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+    {
+        // Load.
+        scalar_t x = (scalar_t)((const T*)p.x)[xi];
+        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+        scalar_t yy = (gain != 0) ? yref / gain : 0;
+        scalar_t y = 0;
+
+        // Apply bias.
+        ((G == 0) ? x : xref) += b;
+
+        // linear
+        if (A == 1)
+        {
+            if (G == 0) y = x;
+            if (G == 1) y = x;
+        }
+
+        // relu
+        if (A == 2)
+        {
+            if (G == 0) y = (x > 0) ? x : 0;
+            if (G == 1) y = (yy > 0) ? x : 0;
+        }
+
+        // lrelu
+        if (A == 3)
+        {
+            if (G == 0) y = (x > 0) ? x : x * alpha;
+            if (G == 1) y = (yy > 0) ? x : x * alpha;
+        }
+
+        // tanh
+        if (A == 4)
+        {
+            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+            if (G == 1) y = x * (one - yy * yy);
+            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+        }
+
+        // sigmoid
+        if (A == 5)
+        {
+            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+            if (G == 1) y = x * yy * (one - yy);
+            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+        }
+
+        // elu
+        if (A == 6)
+        {
+            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+        }
+
+        // selu
+        if (A == 7)
+        {
+            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+        }
+
+        // softplus
+        if (A == 8)
+        {
+            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+            if (G == 1) y = x * (one - exp(-yy));
+            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+        }
+
+        // swish
+        if (A == 9)
+        {
+            if (G == 0)
+                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+            else
+            {
+                scalar_t c = exp(xref);
+                scalar_t d = c + one;
+                if (G == 1)
+                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+                else
+                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+            }
+        }
+
+        // Apply gain.
+        y *= gain * dy;
+
+        // Clamp.
+        if (clamp >= 0)
+        {
+            if (G == 0)
+                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+            else
+                y = (yref > -clamp & yref < clamp) ? y : 0;
+        }
+
+        // Store.
+        ((T*)p.y)[xi] = (T)y;
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+    return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
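
For the `A == 3` (`lrelu`) case with `G == 0`, the kernel applies bias add, leaky ReLU, gain, and an optional symmetric clamp per element, matching the Python reference path in `bias_act.py`. A plain-PyTorch cross-check is sketched below; `alpha=0.2` and `gain=sqrt(2)` are the upstream StyleGAN2 defaults and are assumed here, since `def_alpha`/`def_gain` are defined outside this hunk.

```python
import torch
import torch.nn.functional as F

def lrelu_bias_act_ref(x, b, alpha=0.2, gain=2 ** 0.5, clamp_val=None):
    """Plain-PyTorch equivalent of the A == 3, G == 0 kernel branch."""
    y = F.leaky_relu(x + b.reshape(1, -1, 1, 1), negative_slope=alpha) * gain
    if clamp_val is not None and clamp_val >= 0:
        y = y.clamp(-clamp_val, clamp_val)
    return y

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
print(lrelu_bias_act_ref(x, b).shape)  # torch.Size([2, 8, 4, 4])
```
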
diff --git a/imaginaire/third_party/channelnorm/channelnorm.py b/imaginaire/third_party/channelnorm/channelnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd46711ca0bf2b6bb650112fa364100f6d4c927
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/channelnorm.py
@@ -0,0 +1,39 @@
+# flake8: noqa
+from torch.autograd import Function, Variable
+from torch.nn.modules.module import Module
+import channelnorm_cuda
+
+
+class ChannelNormFunction(Function):
+    @staticmethod
+    def forward(ctx, input1, norm_deg=2):
+        assert input1.is_contiguous()
+        b, _, h, w = input1.size()
+        output = input1.new(b, 1, h, w).zero_()
+
+        channelnorm_cuda.forward(input1, output, norm_deg)
+        ctx.save_for_backward(input1, output)
+        ctx.norm_deg = norm_deg
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input1, output = ctx.saved_tensors
+
+        grad_input1 = Variable(input1.new(input1.size()).zero_())
+
+        channelnorm_cuda.backward(input1, output, grad_output.data,
+                                  grad_input1.data, ctx.norm_deg)
+
+        return grad_input1, None
+
+
+class ChannelNorm(Module):
+
+    def __init__(self, norm_deg=2):
+        super(ChannelNorm, self).__init__()
+        self.norm_deg = norm_deg
+
+    def forward(self, input1):
+        return ChannelNormFunction.apply(input1, self.norm_deg)
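
A minimal usage sketch for `ChannelNorm`, assuming the `channelnorm_cuda` extension built by the `setup.py` below is importable and a CUDA device is available. The bundled kernel computes the per-pixel L2 norm across channels regardless of `norm_deg`, so the plain-PyTorch expression below should match.

```python
import torch

norm = ChannelNorm()                      # norm_deg defaults to 2
x = torch.randn(2, 16, 64, 64, device='cuda').contiguous()
y = norm(x)                               # shape: (2, 1, 64, 64)

# Plain-PyTorch equivalent of what the CUDA kernel computes.
y_ref = x.pow(2).sum(dim=1, keepdim=True).sqrt()
print(y.shape, torch.allclose(y, y_ref, atol=1e-5))
```
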
diff --git a/imaginaire/third_party/channelnorm/setup.py b/imaginaire/third_party/channelnorm/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8503ad2b254915bbbab391eb31baf0dcdc9a6bd1
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    # Compare the major version numerically; a plain string comparison would
+    # mis-order versions such as '9.2' and '11.0'.
+    if int(cuda_version.split('.')[0]) >= 11:
+        nvcc_args.append('-gencode')
+        nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+    name='channelnorm_cuda',
+    py_modules=['channelnorm'],
+    ext_modules=[
+        CUDAExtension('channelnorm_cuda', [
+            './src/channelnorm_cuda.cc',
+            './src/channelnorm_kernel.cu'
+        ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+                               'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc
@@ -0,0 +1,31 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+
+#include "channelnorm_kernel.cuh"
+
+int channelnorm_cuda_forward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    int norm_deg) {
+
+    channelnorm_kernel_forward(input1, output, norm_deg);
+    return 1;
+}
+
+
+int channelnorm_cuda_backward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg) {
+
+    channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
+    return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
+  m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
+}
+
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu
@@ -0,0 +1,177 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "channelnorm_kernel.cuh"
+
+#define CUDA_NUM_THREADS 512 
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
+using at::Half;
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_update_output(
+    const int n, 
+    const scalar_t* __restrict__ input1,
+    const long4 input1_size,
+    const long4 input1_stride,
+    scalar_t* __restrict__ output, 
+    const long4 output_size,
+    const long4 output_stride,
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    int dim_b = DIM0(output_size);
+    int dim_c = DIM1(output_size);
+    int dim_h = DIM2(output_size);
+    int dim_w = DIM3(output_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    int i1dim_c = DIM1(input1_size);
+    int i1dim_h = DIM2(input1_size);
+    int i1dim_w = DIM3(input1_size);
+    int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
+    int i1dim_hw  = i1dim_h * i1dim_w;
+
+    float result = 0.0;
+
+    for (int c = 0; c < i1dim_c; ++c) {
+        int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
+        scalar_t val = input1[i1Index];
+        result += static_cast<float>(val * val);
+    }
+    result = sqrt(result);
+    output[index] = static_cast<scalar_t>(result);
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_backward_input1(
+    const int n,
+    const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, 
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, 
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    float val = 0.0;
+
+    int dim_b = DIM0(gradInput_size);
+    int dim_c = DIM1(gradInput_size);
+    int dim_h = DIM2(gradInput_size);
+    int dim_w = DIM3(gradInput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+
+    int outIndex = b * dim_hw + y * dim_w + x;
+    val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
+    gradInput[index] = static_cast<scalar_t>(val);
+
+}
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1, 
+    at::Tensor& output, 
+    int norm_deg) {
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    int n = output.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
+
+      kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+          n,
+          input1.data<scalar_t>(), 
+          input1_size,
+          input1_stride, 
+          output.data<scalar_t>(),
+          output_size,
+          output_stride, 
+          norm_deg);
+
+    }));
+
+      // TODO: ATen-equivalent check
+
+     // THCudaCheck(cudaGetLastError());
+}
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    at::Tensor& gradOutput, 
+    at::Tensor& gradInput1, 
+    int norm_deg) {
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+    const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+    const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+    const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+    int n = gradInput1.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
+
+      kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+          n, 
+          input1.data<scalar_t>(),
+          input1_size,
+          input1_stride,
+          output.data<scalar_t>(),
+          output_size,
+          output_stride,
+          gradOutput.data<scalar_t>(),
+          gradOutput_size,
+          gradOutput_stride, 
+          gradInput1.data<scalar_t>(),
+          gradInput1_size,
+          gradInput1_stride,
+          norm_deg
+    );
+
+    }));
+
+    // TODO: Add ATen-equivalent check
+
+//    THCudaCheck(cudaGetLastError());
+}
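
The backward kernel writes `grad_input = grad_output * input / (output + 1e-9)`, i.e. the analytic gradient of the L2 norm, d||v||/dv_c = v_c / ||v||. A CPU-only PyTorch sanity check of that formula:

```python
import torch

x = torch.randn(2, 8, 4, 4, requires_grad=True)
out = x.pow(2).sum(dim=1, keepdim=True).sqrt()
out.sum().backward()                       # upstream gradient of ones

manual = x.detach() / (out.detach() + 1e-9)
print(torch.allclose(x.grad, manual, atol=1e-5))  # True
```
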
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& output, 
+    int norm_deg);
+
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg);
diff --git a/imaginaire/third_party/correlation/correlation.py b/imaginaire/third_party/correlation/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e47739dff7475c0f29bff32bc2dc9f097161d144
--- /dev/null
+++ b/imaginaire/third_party/correlation/correlation.py
@@ -0,0 +1,105 @@
+# flake8: noqa
+import torch
+from torch.nn.modules.module import Module
+from torch.autograd import Function
+import correlation_cuda
+
+
+class CorrelationFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+            pad_size,
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2,
+            corr_multiply,
+            input1,
+            input2):
+        ctx.save_for_backward(input1, input2)
+        ctx.pad_size = pad_size
+        ctx.kernel_size = kernel_size
+        ctx.max_displacement = max_displacement
+        ctx.stride1 = stride1
+        ctx.stride2 = stride2
+        ctx.corr_multiply = corr_multiply
+
+        with torch.cuda.device_of(input1):
+            rbot1 = input1.new()
+            rbot2 = input2.new()
+            output = input1.new()
+
+            correlation_cuda.forward(
+                input1,
+                input2,
+                rbot1,
+                rbot2,
+                output,
+                ctx.pad_size,
+                ctx.kernel_size,
+                ctx.max_displacement,
+                ctx.stride1,
+                ctx.stride2,
+                ctx.corr_multiply)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input1, input2 = ctx.saved_tensors
+
+        with torch.cuda.device_of(input1):
+            rbot1 = input1.new()
+            rbot2 = input2.new()
+
+            grad_input1 = input1.new()
+            grad_input2 = input2.new()
+
+            correlation_cuda.backward(
+                input1,
+                input2,
+                rbot1,
+                rbot2,
+                grad_output,
+                grad_input1,
+                grad_input2,
+                ctx.pad_size,
+                ctx.kernel_size,
+                ctx.max_displacement,
+                ctx.stride1,
+                ctx.stride2,
+                ctx.corr_multiply)
+
+        # One gradient must be returned per forward() input: the six integer
+        # hyper-parameters take None, only the two feature tensors get gradients.
+        return (None, None, None, None, None, None,
+                grad_input1, grad_input2)
+
+class Correlation(Module):
+    def __init__(
+            self,
+            pad_size=0,
+            kernel_size=0,
+            max_displacement=0,
+            stride1=1,
+            stride2=2,
+            corr_multiply=1):
+        super(Correlation, self).__init__()
+        self.pad_size = pad_size
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride1 = stride1
+        self.stride2 = stride2
+        self.corr_multiply = corr_multiply
+
+    def forward(self, input1, input2):
+
+        result = CorrelationFunction.apply(
+            self.pad_size,
+            self.kernel_size,
+            self.max_displacement,
+            self.stride1,
+            self.stride2,
+            self.corr_multiply,
+            input1,
+            input2)
+
+        return result
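
A minimal usage sketch for `Correlation` with FlowNetC-style settings (illustrative values; the configuration actually used elsewhere in the codebase is not shown in this hunk). It assumes the `correlation_cuda` extension built by the `setup.py` below is importable and a CUDA device is available.

```python
import torch

corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                   stride1=1, stride2=2, corr_multiply=1)
feat1 = torch.randn(2, 256, 48, 64, device='cuda')
feat2 = torch.randn(2, 256, 48, 64, device='cuda')
cost = corr(feat1, feat2)

# (2 * (max_displacement // stride2) + 1) ** 2 = 441 displacement channels;
# with kernel_size=1 and stride1=1 the spatial size is preserved
# (see the shape arithmetic in correlation_cuda.cc below).
print(cost.shape)  # torch.Size([2, 441, 48, 64])
```
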
diff --git a/imaginaire/third_party/correlation/setup.py b/imaginaire/third_party/correlation/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c02aacc53102a5ef534db1a7cd69c546004f268
--- /dev/null
+++ b/imaginaire/third_party/correlation/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    # Compare the major version numerically; a plain string comparison would
+    # mis-order versions such as '9.2' and '11.0'.
+    if int(cuda_version.split('.')[0]) >= 11:
+        nvcc_args.append('-gencode')
+        nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+    name='correlation_cuda',
+    py_modules=['correlation'],
+    ext_modules=[
+        CUDAExtension('correlation_cuda', [
+            './src/correlation_cuda.cc',
+            './src/correlation_cuda_kernel.cu'
+        ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+                               'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda.cc b/imaginaire/third_party/correlation/src/correlation_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda.cc
@@ -0,0 +1,173 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <stdio.h>
+#include <iostream>
+
+#include "correlation_cuda_kernel.cuh"
+
+int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
+                       int pad_size,
+                       int kernel_size,
+                       int max_displacement,
+                       int stride1,
+                       int stride2,
+                       int corr_type_multiply)
+{
+
+  int batchSize = input1.size(0);
+
+  int nInputChannels = input1.size(1);
+  int inputHeight = input1.size(2);
+  int inputWidth = input1.size(3);
+
+  int kernel_radius = (kernel_size - 1) / 2;
+  int border_radius = kernel_radius + max_displacement;
+
+  int paddedInputHeight = inputHeight + 2 * pad_size;
+  int paddedInputWidth = inputWidth + 2 * pad_size;
+
+  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
+
+  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
+  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
+
+  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
+
+  rInput1.fill_(0);
+  rInput2.fill_(0);
+  output.fill_(0);
+
+  int success = correlation_forward_cuda_kernel(
+    output,
+    output.size(0), 
+    output.size(1),
+    output.size(2),
+    output.size(3),
+    output.stride(0),
+    output.stride(1),
+    output.stride(2),
+    output.stride(3),
+    input1,
+    input1.size(1),
+    input1.size(2),
+    input1.size(3),
+    input1.stride(0),
+    input1.stride(1),
+    input1.stride(2),
+    input1.stride(3),
+    input2,
+    input2.size(1),
+    input2.stride(0),
+    input2.stride(1),
+    input2.stride(2),
+    input2.stride(3),
+    rInput1,
+    rInput2,
+    pad_size,     
+    kernel_size,
+    max_displacement,
+    stride1,
+    stride2,
+    corr_type_multiply,
+	at::cuda::getCurrentCUDAStream()
+    //at::globalContext().getCurrentCUDAStream()
+  );
+
+  //check for errors
+  if (!success) {
+    AT_ERROR("CUDA call failed");
+  }
+
+  return 1;
+
+}
+
+int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 
+                       at::Tensor& gradInput1, at::Tensor& gradInput2,
+                       int pad_size,
+                       int kernel_size,
+                       int max_displacement,
+                       int stride1,
+                       int stride2,
+                       int corr_type_multiply)
+{
+
+  int batchSize = input1.size(0);
+  int nInputChannels = input1.size(1);
+  int paddedInputHeight = input1.size(2)+ 2 * pad_size;
+  int paddedInputWidth = input1.size(3)+ 2 * pad_size;
+
+  int height = input1.size(2);
+  int width = input1.size(3);
+
+  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  gradInput1.resize_({batchSize, nInputChannels, height, width});
+  gradInput2.resize_({batchSize, nInputChannels, height, width});
+
+  rInput1.fill_(0);
+  rInput2.fill_(0);
+  gradInput1.fill_(0);
+  gradInput2.fill_(0);
+
+  int success = correlation_backward_cuda_kernel(gradOutput,
+                                                gradOutput.size(0),
+                                                gradOutput.size(1),
+                                                gradOutput.size(2),
+                                                gradOutput.size(3),
+                                                gradOutput.stride(0),
+                                                gradOutput.stride(1),
+                                                gradOutput.stride(2),
+                                                gradOutput.stride(3),
+                                                input1,
+                                                input1.size(1),
+                                                input1.size(2),
+                                                input1.size(3),
+                                                input1.stride(0),
+                                                input1.stride(1),
+                                                input1.stride(2),
+                                                input1.stride(3),
+                                                input2,  
+                                                input2.stride(0),
+                                                input2.stride(1),
+                                                input2.stride(2),
+                                                input2.stride(3),
+                                                gradInput1,
+                                                gradInput1.stride(0),
+                                                gradInput1.stride(1),
+                                                gradInput1.stride(2),
+                                                gradInput1.stride(3),
+                                                gradInput2,
+                                                gradInput2.size(1),
+                                                gradInput2.stride(0),
+                                                gradInput2.stride(1),
+                                                gradInput2.stride(2),
+                                                gradInput2.stride(3),
+                                                rInput1,
+                                                rInput2,
+                                                pad_size,
+                                                kernel_size,
+                                                max_displacement,
+                                                stride1, 
+                                                stride2,
+                                                corr_type_multiply,
+												at::cuda::getCurrentCUDAStream()
+                                                //at::globalContext().getCurrentCUDAStream()
+                                               );
+
+  if (!success) {
+    AT_ERROR("CUDA call failed");
+  }
+
+  return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
+  m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
+}
+
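
The output geometry produced by `correlation_forward_cuda` can be summarized with a small helper that mirrors the arithmetic above (a sketch for reference, not part of the extension):

```python
import math

def correlation_output_shape(batch, height, width, pad_size, kernel_size,
                             max_displacement, stride1, stride2):
    """Mirror of the shape arithmetic in correlation_forward_cuda()."""
    kernel_radius = (kernel_size - 1) // 2
    border_radius = kernel_radius + max_displacement
    padded_h = height + 2 * pad_size
    padded_w = width + 2 * pad_size
    n_out_channels = ((max_displacement // stride2) * 2 + 1) ** 2
    out_h = math.ceil((padded_h - 2 * border_radius) / stride1)
    out_w = math.ceil((padded_w - 2 * border_radius) / stride1)
    return batch, n_out_channels, out_h, out_w

print(correlation_output_shape(2, 48, 64, 20, 1, 20, 1, 2))  # (2, 441, 48, 64)
```
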
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu
@@ -0,0 +1,564 @@
+#include <stdio.h>
+
+#include "correlation_cuda_kernel.cuh"
+
+#define CUDA_NUM_THREADS 1024
+#define THREADS_PER_BLOCK 32
+#define FULL_MASK 0xffffffff
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+using at::Half;
+
+template<typename scalar_t>
+__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) {
+        for (int offset = 16; offset > 0; offset /= 2)
+                val += __shfl_down_sync(FULL_MASK, val, offset);
+        return val;
+}
+
+template<typename scalar_t>
+__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) {
+
+        static __shared__ scalar_t shared[32];
+        int lane = threadIdx.x % warpSize;
+        int wid = threadIdx.x / warpSize;
+
+        val = warpReduceSum(val);
+
+        if (lane == 0)
+                shared[wid] = val;
+
+        __syncthreads();
+
+        val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+
+        if (wid == 0)
+                val = warpReduceSum(val);
+
+        return val;
+}
+
+
+template <typename scalar_t>
+__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
+{
+
+    // n (batch size), c (num of channels), y (height), x (width)
+    int n = blockIdx.x;
+    int y = blockIdx.y;
+    int x = blockIdx.z;
+
+    int ch_off = threadIdx.x;
+    scalar_t value;
+
+    int dimcyx = channels * height * width;
+    int dimyx = height * width;
+
+    int p_dimx = (width + 2 * pad_size);
+    int p_dimy = (height + 2 * pad_size);
+    int p_dimyxc = channels * p_dimy * p_dimx;
+    int p_dimxc = p_dimx * channels;
+
+    for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
+      value = input[n * dimcyx + c * dimyx + y * width + x];
+      rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
+    }
+}
+
+
+template<typename scalar_t>
+__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels,
+                const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1,
+                const int nInputChannels, const int inputHeight, const int inputWidth,
+                const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size,
+                const int max_displacement, const int stride1, const int stride2) {
+
+        int32_t pInputWidth = inputWidth + 2 * pad_size;
+        int32_t pInputHeight = inputHeight + 2 * pad_size;
+
+        int32_t kernel_rad = (kernel_size - 1) / 2;
+
+        int32_t displacement_rad = max_displacement / stride2;
+
+        int32_t displacement_size = 2 * displacement_rad + 1;
+
+        int32_t n = blockIdx.x;
+        int32_t y1 = blockIdx.y * stride1 + max_displacement;
+        int32_t x1 = blockIdx.z * stride1 + max_displacement;
+        int32_t c = threadIdx.x;
+
+        int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+
+        int32_t pdimxc = pInputWidth * nInputChannels;
+
+        int32_t pdimc = nInputChannels;
+
+        int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth;
+        int32_t tdimyx = outputHeight * outputWidth;
+        int32_t tdimx = outputWidth;
+
+        int32_t nelems = kernel_size * kernel_size * pdimc;
+
+        // element-wise product along channel axis
+        for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+                for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+                        int x2 = x1 + ti * stride2;
+                        int y2 = y1 + tj * stride2;
+
+                        float acc0 = 0.0f;
+
+                        for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+                                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+                                        // THREADS_PER_BLOCK
+                                        #pragma unroll
+                                        for (int ch = c; ch < pdimc; ch += blockDim.x) {
+
+                                                int indx1 = n * pdimyxc + (y1 + j) * pdimxc
+                                                                + (x1 + i) * pdimc + ch;
+                                                int indx2 = n * pdimyxc + (y2 + j) * pdimxc
+                                                                + (x2 + i) * pdimc + ch;
+                                                acc0 += static_cast<float>(rInput1[indx1] * rInput2[indx2]);
+                                        }
+                                }
+                        }
+
+                        if (blockDim.x == warpSize) {
+                            __syncwarp();
+                            acc0 = warpReduceSum(acc0);
+                        } else {
+                            __syncthreads();
+                            acc0 = blockReduceSum(acc0);
+                        }
+
+                        if (threadIdx.x == 0) {
+
+                                int tc = (tj + displacement_rad) * displacement_size
+                                                + (ti + displacement_rad);
+                                const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx
+                                                + blockIdx.z;
+                                output[tindx] = static_cast<scalar_t>(acc0 / nelems);
+                        }
+            }
+        }
+}
+
+
+template <typename scalar_t>
+__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 
+                                            const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 
+                                            const scalar_t* __restrict__ rInput2, 
+                                            int pad_size,
+                                            int kernel_size,
+                                            int max_displacement,
+                                            int stride1,
+                                            int stride2)
+  {
+    // n (batch size), c (num of channels), y (height), x (width)
+
+    int n = item; 
+    int y = blockIdx.x * stride1 + pad_size;
+    int x = blockIdx.y * stride1 + pad_size;
+    int c = blockIdx.z;
+    int tch_off = threadIdx.x;
+
+    int kernel_rad = (kernel_size - 1) / 2;
+    int displacement_rad = max_displacement / stride2;
+    int displacement_size = 2 * displacement_rad + 1;
+
+    int xmin = (x - kernel_rad - max_displacement) / stride1;
+    int ymin = (y - kernel_rad - max_displacement) / stride1;
+
+    int xmax = (x + kernel_rad - max_displacement) / stride1;
+    int ymax = (y + kernel_rad - max_displacement) / stride1;
+
+    if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+        // assumes gradInput1 is pre-allocated and zero filled
+      return;
+    }
+
+    if (xmin > xmax || ymin > ymax) {
+        // assumes gradInput1 is pre-allocated and zero filled
+        return;
+    }
+
+    xmin = max(0,xmin);
+    xmax = min(outputWidth-1,xmax);
+
+    ymin = max(0,ymin);
+    ymax = min(outputHeight-1,ymax);
+
+    int pInputWidth = inputWidth + 2 * pad_size;
+    int pInputHeight = inputHeight + 2 * pad_size;
+
+    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+    int pdimxc = pInputWidth * nInputChannels;
+    int pdimc = nInputChannels;
+
+    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+    int tdimyx = outputHeight * outputWidth;
+    int tdimx = outputWidth;
+
+    int odimcyx = nInputChannels * inputHeight* inputWidth;
+    int odimyx = inputHeight * inputWidth;
+    int odimx = inputWidth;
+
+    scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+    prod_sum[tch_off] = 0;
+
+    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+
+      int i2 = (tc % displacement_size - displacement_rad) * stride2;
+      int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+      int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
+      
+      scalar_t val2 = rInput2[indx2];
+
+      for (int j = ymin; j <= ymax; ++j) {
+        for (int i = xmin; i <= xmax; ++i) {
+          int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+          prod_sum[tch_off] += gradOutput[tindx] * val2;
+        }
+      }
+    }
+    __syncthreads();
+
+    if(tch_off == 0) {
+      scalar_t reduce_sum = 0;
+      for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+          reduce_sum += prod_sum[idx];
+      }
+      const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+      gradInput1[indx1] = reduce_sum / nelems;
+    }
+
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_input2(int item, scalar_t*  gradInput2, int nInputChannels, int inputHeight, int inputWidth,
+                                            const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
+                                            const scalar_t* __restrict__ rInput1,
+                                            int pad_size,
+                                            int kernel_size,
+                                            int max_displacement,
+                                            int stride1,
+                                            int stride2)
+{
+    // n (batch size), c (num of channels), y (height), x (width)
+
+    int n = item;
+    int y = blockIdx.x * stride1 + pad_size;
+    int x = blockIdx.y * stride1 + pad_size;
+    int c = blockIdx.z;
+
+    int tch_off = threadIdx.x;
+
+    int kernel_rad = (kernel_size - 1) / 2;
+    int displacement_rad = max_displacement / stride2;
+    int displacement_size = 2 * displacement_rad + 1;
+
+    int pInputWidth = inputWidth + 2 * pad_size;
+    int pInputHeight = inputHeight + 2 * pad_size;
+
+    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+    int pdimxc = pInputWidth * nInputChannels;
+    int pdimc = nInputChannels;
+
+    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+    int tdimyx = outputHeight * outputWidth;
+    int tdimx = outputWidth;
+
+    int odimcyx = nInputChannels * inputHeight* inputWidth;
+    int odimyx = inputHeight * inputWidth;
+    int odimx = inputWidth;
+
+    scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+    prod_sum[tch_off] = 0;
+
+    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+      int i2 = (tc % displacement_size - displacement_rad) * stride2;
+      int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+      int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
+      int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
+
+      int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
+      int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
+
+      if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+          // assumes gradInput2 is pre-allocated and zero filled
+        continue;
+      }
+
+      if (xmin > xmax || ymin > ymax) {
+          // assumes gradInput2 is pre-allocated and zero filled
+          continue;
+      }
+
+      xmin = max(0,xmin);
+      xmax = min(outputWidth-1,xmax);
+
+      ymin = max(0,ymin);
+      ymax = min(outputHeight-1,ymax);
+      
+      int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
+      scalar_t val1 = rInput1[indx1];
+
+      for (int j = ymin; j <= ymax; ++j) {
+        for (int i = xmin; i <= xmax; ++i) {
+          int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+          prod_sum[tch_off] += gradOutput[tindx] * val1;
+        }
+      }
+    }
+
+    __syncthreads();
+
+    if(tch_off == 0) {
+      scalar_t reduce_sum = 0;
+      for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+          reduce_sum += prod_sum[idx];
+      }
+      const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+      gradInput2[indx2] = reduce_sum / nelems;
+    }
+
+}
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+                                    int ob,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int osb,
+                                    int osc,
+                                    int osh,
+                                    int osw,
+
+                                    at::Tensor& input1,
+                                    int ic,
+                                    int ih,
+                                    int iw,
+                                    int isb,
+                                    int isc,
+                                    int ish,
+                                    int isw,
+
+                                    at::Tensor& input2,
+                                    int gc,
+                                    int gsb,
+                                    int gsc,
+                                    int gsh,
+                                    int gsw,
+
+                                    at::Tensor& rInput1,
+                                    at::Tensor& rInput2,
+                                    int pad_size,
+                                    int kernel_size,
+                                    int max_displacement,
+                                    int stride1,
+                                    int stride2,
+                                    int corr_type_multiply,
+                                    cudaStream_t stream) 
+{
+
+   int batchSize = ob;
+
+   int nInputChannels = ic;
+   int inputWidth = iw;
+   int inputHeight = ih;
+
+   int nOutputChannels = oc;
+   int outputWidth = ow;
+   int outputHeight = oh;
+
+   dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+   dim3 threads_block(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
+
+  channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>>(
+      input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+  }));
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
+
+  channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>> (
+      input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+  }));
+
+   dim3 threadsPerBlock(THREADS_PER_BLOCK);
+   dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
+
+   correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> 
+                        (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+                         rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+                         rInput2.data<scalar_t>(),
+                         pad_size,
+                         kernel_size,
+                         max_displacement,
+                         stride1,
+                         stride2);
+
+  }));
+
+  cudaError_t err = cudaGetLastError();
+
+
+  // check for errors
+  if (err != cudaSuccess) {
+    printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
+    return 0;
+  }
+
+  return 1;
+}
+
+
+int correlation_backward_cuda_kernel(
+                                    at::Tensor& gradOutput,
+                                    int gob,
+                                    int goc,
+                                    int goh,
+                                    int gow,
+                                    int gosb,
+                                    int gosc,
+                                    int gosh,
+                                    int gosw,
+
+                                    at::Tensor& input1,
+                                    int ic,
+                                    int ih,
+                                    int iw,
+                                    int isb,
+                                    int isc,
+                                    int ish,
+                                    int isw,
+
+                                    at::Tensor& input2,
+                                    int gsb,
+                                    int gsc,
+                                    int gsh,
+                                    int gsw,
+
+                                    at::Tensor& gradInput1,
+                                    int gisb,
+                                    int gisc,
+                                    int gish,
+                                    int gisw,
+
+                                    at::Tensor& gradInput2,
+                                    int ggc,
+                                    int ggsb,
+                                    int ggsc,
+                                    int ggsh,
+                                    int ggsw,
+
+                                    at::Tensor& rInput1,
+                                    at::Tensor& rInput2,
+                                    int pad_size,
+                                    int kernel_size,
+                                    int max_displacement,
+                                    int stride1,
+                                    int stride2,
+                                    int corr_type_multiply,
+                                    cudaStream_t stream)
+{
+
+    int batchSize = gob;
+    int num = batchSize;
+
+    int nInputChannels = ic;
+    int inputWidth = iw;
+    int inputHeight = ih;
+
+    int nOutputChannels = goc;
+    int outputWidth = gow;
+    int outputHeight = goh;
+
+    dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_bwd_1", ([&] {
+
+        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+            input1.data<scalar_t>(),
+            rInput1.data<scalar_t>(),
+            nInputChannels,
+            inputHeight,
+            inputWidth,
+            pad_size
+        );
+    }));
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_bwd_2", ([&] {
+
+        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+            input2.data<scalar_t>(),
+            rInput2.data<scalar_t>(),
+            nInputChannels,
+            inputHeight,
+            inputWidth,
+            pad_size
+        );
+    }));
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
+
+    for (int n = 0; n < num; ++n) {
+
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "correlation_backward_input1", ([&] {
+
+
+          correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
+              n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+              gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+              rInput2.data<scalar_t>(),
+              pad_size,
+              kernel_size,
+              max_displacement,
+              stride1,
+              stride2);
+      }));
+    }
+
+    for(int n = 0; n < batchSize; n++) {
+
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "correlation_backward_input2", ([&] {
+
+        correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
+            n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+            gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+            rInput1.data<scalar_t>(),
+            pad_size,
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2);
+
+        }));
+    }
+
+  // check for errors
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
+    return 0;
+  }
+
+  return 1;
+}
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh
@@ -0,0 +1,91 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <cuda_runtime.h>
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+    int ob,
+    int oc,
+    int oh,
+    int ow,
+    int osb,
+    int osc,
+    int osh,
+    int osw,
+
+    at::Tensor& input1,
+    int ic,
+    int ih,
+    int iw,
+    int isb,
+    int isc,
+    int ish,
+    int isw,
+
+    at::Tensor& input2,
+    int gc,
+    int gsb,
+    int gsc,
+    int gsh,
+    int gsw,
+
+    at::Tensor& rInput1,
+    at::Tensor& rInput2,
+    int pad_size,
+    int kernel_size,
+    int max_displacement,
+    int stride1,
+    int stride2,
+    int corr_type_multiply,
+    cudaStream_t stream);
+
+
+int correlation_backward_cuda_kernel(   
+    at::Tensor& gradOutput,
+    int gob,
+    int goc,
+    int goh,
+    int gow,
+    int gosb,
+    int gosc,
+    int gosh,
+    int gosw,
+
+    at::Tensor& input1,
+    int ic,
+    int ih,
+    int iw,
+    int isb,
+    int isc,
+    int ish,
+    int isw,
+
+    at::Tensor& input2,
+    int gsb,
+    int gsc,
+    int gsh,
+    int gsw,
+
+    at::Tensor& gradInput1, 
+    int gisb,
+    int gisc,
+    int gish,
+    int gisw,
+
+    at::Tensor& gradInput2,
+    int ggc,
+    int ggsb,
+    int ggsc,
+    int ggsh,
+    int ggsw,
+
+    at::Tensor& rInput1,
+    at::Tensor& rInput2,
+    int pad_size,
+    int kernel_size,
+    int max_displacement,
+    int stride1,
+    int stride2,
+    int corr_type_multiply,
+    cudaStream_t stream);
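
These launchers are reached from Python through the compiled `correlation` extension, which `flownet_c.py` further down in this diff imports and wraps as a module. A minimal usage sketch, assuming the extension is built and a CUDA device is available (shapes are illustrative only):

```python
# Sketch only: exercises the correlation CUDA kernels via the compiled extension.
import torch
import correlation  # compiled from the sources above

# Same hyper-parameters FlowNetC uses below.
corr = correlation.Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                               stride1=1, stride2=2, corr_multiply=1)
feat_a = torch.rand(1, 256, 48, 64, device='cuda')  # conv3 features of frame 0
feat_b = torch.rand(1, 256, 48, 64, device='cuda')  # conv3 features of frame 1
cost_volume = corr(feat_a, feat_b)
print(cost_volume.shape)  # expected (1, 441, 48, 64): (2 * (20 // 2) + 1) ** 2 displacement bins
```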
diff --git a/imaginaire/third_party/flow_net/__init__.py b/imaginaire/third_party/flow_net/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/third_party/flow_net/flow_net.py b/imaginaire/third_party/flow_net/flow_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..41759c50fa1389b6fbe2e5db725a4b71ff3f2342
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flow_net.py
@@ -0,0 +1,89 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import types
+from imaginaire.third_party.flow_net.flownet2 import models as \
+    flownet2_models
+from imaginaire.third_party.flow_net.flownet2.utils import tools \
+    as flownet2_tools
+from imaginaire.model_utils.fs_vid2vid import resample
+from imaginaire.utils.io import get_checkpoint
+
+
+class FlowNet(nn.Module):
+    def __init__(self, pretrained=True, fp16=False):
+        super().__init__()
+        flownet2_args = types.SimpleNamespace()
+        setattr(flownet2_args, 'fp16', fp16)
+        setattr(flownet2_args, 'rgb_max', 1.0)
+        if fp16:
+            print('FlowNet2 is running in fp16 mode.')
+        self.flowNet = flownet2_tools.module_to_dict(flownet2_models)[
+            'FlowNet2'](flownet2_args).to('cuda')
+        if pretrained:
+            flownet2_path = get_checkpoint('flownet2.pth.tar',
+                                           '1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da')
+            checkpoint = torch.load(flownet2_path,
+                                    map_location=torch.device('cpu'))
+            self.flowNet.load_state_dict(checkpoint['state_dict'])
+        self.flowNet.eval()
+
+    def forward(self, input_A, input_B):
+        size = input_A.size()
+        assert(len(size) == 4 or len(size) == 5 or len(size) == 6)
+        if len(size) >= 5:
+            if len(size) == 5:
+                b, n, c, h, w = size
+            else:
+                b, t, n, c, h, w = size
+            input_A = input_A.contiguous().view(-1, c, h, w)
+            input_B = input_B.contiguous().view(-1, c, h, w)
+            flow, conf = self.compute_flow_and_conf(input_A, input_B)
+            if len(size) == 5:
+                return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w)
+            else:
+                return flow.view(b, t, n, 2, h, w), conf.view(b, t, n, 1, h, w)
+        else:
+            return self.compute_flow_and_conf(input_A, input_B)
+
+    def compute_flow_and_conf(self, im1, im2):
+        assert(im1.size()[1] == 3)
+        assert(im1.size() == im2.size())
+        old_h, old_w = im1.size()[2], im1.size()[3]
+        new_h, new_w = old_h // 64 * 64, old_w // 64 * 64
+        if old_h != new_h:
+            im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear',
+                                align_corners=False)
+            im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear',
+                                align_corners=False)
+        data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
+        with torch.no_grad():
+            flow1 = self.flowNet(data1)
+        # img_diff = torch.sum(abs(im1 - resample(im2, flow1)),
+        #                      dim=1, keepdim=True)
+        # conf = torch.clamp(1 - img_diff, 0, 1)
+
+        conf = (self.norm(im1 - resample(im2, flow1)) < 0.02).float()
+
+        # data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2)
+        # with torch.no_grad():
+        #     flow2 = self.flowNet(data2)
+        # warped_flow2 = resample(flow2, flow1)
+        # flow_sum = self.norm(flow1 + warped_flow2)
+        # disocc = flow_sum > (0.05 * (self.norm(flow1) +
+        # self.norm(warped_flow2)) + 0.5)
+        # conf = 1 - disocc.float()
+
+        if old_h != new_h:
+            flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear',
+                                  align_corners=False) * old_h / new_h
+            conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear',
+                                 align_corners=False)
+        return flow1, conf
+
+    def norm(self, t):
+        return torch.sum(t * t, dim=1, keepdim=True)
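
A minimal usage sketch for the `FlowNet` wrapper above, assuming the compiled flownet2 extensions are importable and a CUDA device is present; `pretrained=False` skips the checkpoint download performed by `get_checkpoint`:

```python
# Sketch only: FlowNet.forward with 4-D inputs (B, 3, H, W).
import torch
from imaginaire.third_party.flow_net.flow_net import FlowNet

flow_net = FlowNet(pretrained=False)  # pretrained=True fetches flownet2.pth.tar
im_a = torch.rand(2, 3, 256, 256, device='cuda')  # multiples of 64 avoid the internal resize
im_b = torch.rand(2, 3, 256, 256, device='cuda')
flow, conf = flow_net(im_a, im_b)
print(flow.shape, conf.shape)  # (2, 2, 256, 256), (2, 1, 256, 256)
```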
diff --git a/imaginaire/third_party/flow_net/flownet2/models.py b/imaginaire/third_party/flow_net/flownet2/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0066464a01942c85909a7e5ddbc97e39f244623
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/models.py
@@ -0,0 +1,474 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.nn import init
+import torch.nn as nn
+import resample2d
+import channelnorm
+import numpy as np
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_c
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_s
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_sd
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_fusion
+from imaginaire.third_party.flow_net.flownet2.networks.submodules import \
+    tofp16, tofp32
+'Parameter count = 162,518,834'
+
+
+class FlowNet2(nn.Module):
+    def __init__(self, args, use_batch_norm=False, div_flow=20.):
+        super(FlowNet2, self).__init__()
+        self.batch_norm = use_batch_norm
+        self.div_flow = div_flow
+        self.rgb_max = args.rgb_max
+        self.args = args
+        self.channelnorm = channelnorm.ChannelNorm()
+        # First Block (FlowNetC)
+        self.flownetc = flownet_c.FlowNetC(
+            args, use_batch_norm=self.batch_norm)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        self.args = args
+        # if args.fp16:
+        #     self.resample1 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample1 = resample2d.Resample2d()
+        # Block (FlowNetS1)
+        self.flownets_1 = flownet_s.FlowNetS(
+            args, use_batch_norm=self.batch_norm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        # if args.fp16:
+        #     self.resample2 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample2 = resample2d.Resample2d()
+        # Block (FlowNetS2)
+        self.flownets_2 = flownet_s.FlowNetS(
+            args, use_batch_norm=self.batch_norm)
+        # Block (FlowNetSD)
+        self.flownets_d = flownet_sd.FlowNetSD(
+            args, use_batch_norm=self.batch_norm)
+        self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest')
+        self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest')
+        # if args.fp16:
+        #     self.resample3 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample3 = resample2d.Resample2d()
+        # if args.fp16:
+        #     self.resample4 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample4 = resample2d.Resample2d()
+        # Block (FLowNetFusion)
+        self.flownetfusion = flownet_fusion.FlowNetFusion(
+            args, use_batch_norm=self.batch_norm)
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+    def init_deconv_bilinear(self, weight):
+        f_shape = weight.size()
+        height, width = f_shape[-2], f_shape[-1]
+        f = np.ceil(width / 2.0)
+        c = (2 * f - 1 - f % 2) / (2.0 * f)
+        bilinear = np.zeros([height, width])
+        for x in range(width):
+            for y in range(height):
+                value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+                bilinear[y, x] = value
+        min_dim = min(f_shape[0], f_shape[1])
+        weight.data.fill_(0.)
+        for i in range(min_dim):
+            weight.data[i, i, :, :] = torch.from_numpy(bilinear)
+        return
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x1 = x[:, :, 0, :, :]
+        x2 = x[:, :, 1, :, :]
+        x = torch.cat((x1, x2), dim=1)
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+        # warp img1 to img0;
+        # magnitude of diff between img0 and warped_img1,
+        if self.args.fp16:
+            resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+                                            flownetc_flow)
+            resampled_img1 = tofp16()(resampled_img1)
+        else:
+            resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+        diff_img0 = x[:, :3, :, :] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+        # concat img0, img1, img1->img0, flow, diff-mag ;
+        concat1 = torch.cat(
+            (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+            dim=1)
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+        # warp img1 to img0 using flownets1;
+        # magnitude of diff between img0 and warped_img1
+        if self.args.fp16:
+            resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]),
+                                            flownets1_flow)
+            resampled_img1 = tofp16()(resampled_img1)
+        else:
+            resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow)
+        diff_img0 = x[:, :3, :, :] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+        # concat img0, img1, img1->img0, flow, diff-mag
+        concat2 = torch.cat(
+            (x, resampled_img1, flownets1_flow / self.div_flow,
+             norm_diff_img0),
+            dim=1)
+        # flownets2
+        flownets2_flow2 = self.flownets_2(concat2)[0]
+        flownets2_flow = self.upsample4(flownets2_flow2 * self.div_flow)
+        norm_flownets2_flow = self.channelnorm(flownets2_flow)
+        if self.args.fp16:
+            diff_flownets2_flow = self.resample4(tofp32()(x[:, 3:, :, :]),
+                                                 flownets2_flow)
+            diff_flownets2_flow = tofp16()(diff_flownets2_flow)
+        else:
+            diff_flownets2_flow = self.resample4(x[:, 3:, :, :], flownets2_flow)
+        diff_flownets2_img1 = self.channelnorm(
+            (x[:, :3, :, :] - diff_flownets2_flow))
+        # flownetsd
+        flownetsd_flow2 = self.flownets_d(x)[0]
+        flownetsd_flow = self.upsample3(flownetsd_flow2 / self.div_flow)
+        norm_flownetsd_flow = self.channelnorm(flownetsd_flow)
+        if self.args.fp16:
+            diff_flownetsd_flow = self.resample3(tofp32()(x[:, 3:, :, :]),
+                                                 flownetsd_flow)
+            diff_flownetsd_flow = tofp16()(diff_flownetsd_flow)
+        else:
+            diff_flownetsd_flow = self.resample3(x[:, 3:, :, :], flownetsd_flow)
+        diff_flownetsd_img1 = self.channelnorm(
+            (x[:, :3, :, :] - diff_flownetsd_flow))
+        # concat img1 flownetsd, flownets2, norm_flownetsd,
+        # norm_flownets2, diff_flownetsd_img1, diff_flownets2_img1
+        concat3 = torch.cat((x[:, :3, :, :], flownetsd_flow, flownets2_flow,
+                             norm_flownetsd_flow, norm_flownets2_flow,
+                             diff_flownetsd_img1, diff_flownets2_img1), dim=1)
+        flownetfusion_flow = self.flownetfusion(concat3)
+        return flownetfusion_flow
+
+
+class FlowNet2C(flownet_c.FlowNetC):
+    def __init__(self, args, use_batch_norm=False, div_flow=20):
+        super(FlowNet2C, self).__init__(
+            args, use_batch_norm=use_batch_norm, div_flow=div_flow)
+        self.rgb_max = args.rgb_max
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x1 = x[:, :, 0, :, :]
+        x2 = x[:, :, 1, :, :]
+        # FlownetC top input stream
+        out_conv1a = self.conv1(x1)
+        out_conv2a = self.conv2(out_conv1a)
+        out_conv3a = self.conv3(out_conv2a)
+        # FlownetC bottom input stream
+        out_conv1b = self.conv1(x2)
+        out_conv2b = self.conv2(out_conv1b)
+        out_conv3b = self.conv3(out_conv2b)
+        # Merge streams
+        out_corr = self.corr(out_conv3a, out_conv3b)  # False
+        out_corr = self.corr_activation(out_corr)
+        # Redirect top input stream and concatenate
+        out_conv_redir = self.conv_redir(out_conv3a)
+        in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+        # Merged conv layers
+        out_conv3_1 = self.conv3_1(in_conv3_1)
+        out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+        flow5 = self.predict_flow5(concat5)
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+        flow4 = self.predict_flow4(concat4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1)
+        flow3 = self.predict_flow3(concat3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
+        flow2 = self.predict_flow2(concat2)
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return self.upsample1(flow2 * self.div_flow)
+
+
+class FlowNet2S(flownet_s.FlowNetS):
+    def __init__(self, args, use_batch_norm=False, div_flow=20):
+        super(FlowNet2S, self).__init__(args, input_channels=6,
+                                        use_batch_norm=use_batch_norm)
+        self.rgb_max = args.rgb_max
+        self.div_flow = div_flow
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1)
+        out_conv1 = self.conv1(x)
+        out_conv2 = self.conv2(out_conv1)
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+        flow5 = self.predict_flow5(concat5)
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+        flow4 = self.predict_flow4(concat4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+        flow3 = self.predict_flow3(concat3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+        flow2 = self.predict_flow2(concat2)
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return self.upsample1(flow2 * self.div_flow)
+
+
+class FlowNet2SD(flownet_sd.FlowNetSD):
+    def __init__(self, args, use_batch_norm=False, div_flow=20):
+        super(FlowNet2SD, self).__init__(args, use_batch_norm=use_batch_norm)
+        self.rgb_max = args.rgb_max
+        self.div_flow = div_flow
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1)
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+        out_interconv5 = self.inter_conv5(concat5)
+        flow5 = self.predict_flow5(out_interconv5)
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+        out_interconv4 = self.inter_conv4(concat4)
+        flow4 = self.predict_flow4(out_interconv4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+        out_interconv3 = self.inter_conv3(concat3)
+        flow3 = self.predict_flow3(out_interconv3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+        out_interconv2 = self.inter_conv2(concat2)
+        flow2 = self.predict_flow2(out_interconv2)
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return self.upsample1(flow2 * self.div_flow)
+
+
+class FlowNet2CS(nn.Module):
+    def __init__(self, args, use_batch_norm=False, div_flow=20.):
+        super(FlowNet2CS, self).__init__()
+        self.use_batch_norm = use_batch_norm
+        self.div_flow = div_flow
+        self.rgb_max = args.rgb_max
+        self.args = args
+        self.channelnorm = channelnorm.ChannelNorm()
+        # First Block (FlowNetC)
+        self.flownetc = flownet_c.FlowNetC(
+            args, use_batch_norm=self.use_batch_norm)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        self.args = args
+        # if args.fp16:
+        #     self.resample1 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample1 = resample2d.Resample2d()
+        # Block (FlowNetS1)
+        self.flownets_1 = flownet_s.FlowNetS(
+            args, use_batch_norm=self.use_batch_norm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x1 = x[:, :, 0, :, :]
+        x2 = x[:, :, 1, :, :]
+        x = torch.cat((x1, x2), dim=1)
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+        # warp img1 to img0;
+        # magnitude of diff between img0 and warped_img1,
+        if self.args.fp16:
+            resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+                                            flownetc_flow)
+            resampled_img1 = tofp16()(resampled_img1)
+        else:
+            resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+        diff_img0 = x[:, :3, :, :] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+        # concat img0, img1, img1->img0, flow, diff-mag ;
+        concat1 = torch.cat(
+            (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+            dim=1)
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+        return flownets1_flow
+
+
+class FlowNet2CSS(nn.Module):
+    def __init__(self, args, use_batch_norm=False, div_flow=20.):
+        super(FlowNet2CSS, self).__init__()
+        self.use_batch_norm = use_batch_norm
+        self.div_flow = div_flow
+        self.rgb_max = args.rgb_max
+        self.args = args
+        self.channelnorm = channelnorm.ChannelNorm()
+        # First Block (FlowNetC)
+        self.flownetc = flownet_c.FlowNetC(
+            args, use_batch_norm=self.use_batch_norm)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        self.args = args
+        # if args.fp16:
+        #     self.resample1 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample1 = resample2d.Resample2d()
+        # Block (FlowNetS1)
+        self.flownets_1 = flownet_s.FlowNetS(
+            args, use_batch_norm=self.use_batch_norm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+        # if args.fp16:
+        #     self.resample2 = nn.Sequential(
+        #         tofp32(), resample2d.Resample2d(), tofp16())
+        # else:
+        self.resample2 = resample2d.Resample2d()
+        # Block (FlowNetS2)
+        self.flownets_2 = flownet_s.FlowNetS(
+            args, use_batch_norm=self.use_batch_norm)
+        self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest')
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+            dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x1 = x[:, :, 0, :, :]
+        x2 = x[:, :, 1, :, :]
+        x = torch.cat((x1, x2), dim=1)
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+        # Warp img1 to img0;
+        # Magnitude of diff between img0 and warped_img1,
+        if self.args.fp16:
+            resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+                                            flownetc_flow)
+            resampled_img1 = tofp16()(resampled_img1)
+        else:
+            resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+        diff_img0 = x[:, :3, :, :] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+        # concat img0, img1, img1->img0, flow, diff-mag ;
+        concat1 = torch.cat(
+            (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+            dim=1)
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+        # Warp img1 to img0 using flownets1;
+        # magnitude of diff between img0 and warped_img1
+        if self.args.fp16:
+            resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]),
+                                            flownets1_flow)
+            resampled_img1 = tofp16()(resampled_img1)
+        else:
+            resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow)
+        diff_img0 = x[:, :3, :, :] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+        # concat img0, img1, img1->img0, flow, diff-mag
+        concat2 = torch.cat(
+            (x, resampled_img1, flownets1_flow / self.div_flow,
+             norm_diff_img0),
+            dim=1)
+        # flownets2
+        flownets2_flow2 = self.flownets_2(concat2)[0]
+        flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow)
+        return flownets2_flow
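
The stacked networks above only read `args.rgb_max` and `args.fp16`, and they expect the two frames stacked along a depth axis. A construction sketch under those assumptions (sizes are illustrative; H and W should be multiples of 64 so the encoder strides divide evenly):

```python
# Sketch only: builds FlowNet2 the same way flow_net.py does, on dummy frames.
import types
import torch
from imaginaire.third_party.flow_net.flownet2 import models as flownet2_models

args = types.SimpleNamespace(fp16=False, rgb_max=1.0)
net = flownet2_models.FlowNet2(args).to('cuda').eval()

pair = torch.rand(1, 3, 2, 192, 256, device='cuda')  # (B, 3, 2, H, W), values in [0, rgb_max]
with torch.no_grad():
    flow = net(pair)
print(flow.shape)  # (1, 2, 192, 256): fused dense flow for the frame pair
```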
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/__init__.py b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb7ad34e7ce1c37ea1653d73cde323cdb5569e4
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b3719c26eda72d61429c3c52b49707cf48c558
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py
@@ -0,0 +1,160 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+from torch.nn import init
+import correlation
+import torch
+import torch.nn as nn
+from .submodules import conv, predict_flow, deconv, tofp16, tofp32
+
+
+class FlowNetC(nn.Module):
+    def __init__(self, args, use_batch_norm=True, div_flow=20):
+        r"""FlowNet2 C module. Check out the FlowNet2 paper for more details
+        https://arxiv.org/abs/1612.01925
+
+        Args:
+            args (obj): Network initialization arguments
+            use_batch_norm (bool): Use batch norm or not. Default is true.
+            div_flow (int): Flow division factor. Default is 20.
+        """
+        super(FlowNetC, self).__init__()
+
+        self.use_batch_norm = use_batch_norm
+        self.div_flow = div_flow
+
+        self.conv1 = conv(self.use_batch_norm, 3, 64, kernel_size=7, stride=2)
+        self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2)
+        self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5,
+                          stride=2)
+        self.conv_redir = conv(self.use_batch_norm, 256, 32,
+                               kernel_size=1, stride=1)
+        self.args = args
+        # if args.fp16:
+        #     self.corr = nn.Sequential(
+        #         tofp32(),
+        #         correlation.Correlation(pad_size=20, kernel_size=1,
+        #                                 max_displacement=20, stride1=1,
+        #                                 stride2=2, corr_multiply=1),
+        #         tofp16())
+        # else:
+        self.corr = correlation.Correlation(pad_size=20, kernel_size=1,
+                                            max_displacement=20, stride1=1,
+                                            stride2=2, corr_multiply=1)
+
+        self.corr_activation = nn.LeakyReLU(0.1, inplace=True)
+        self.conv3_1 = conv(self.use_batch_norm, 473, 256)
+        self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+        self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+        self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+        self.deconv5 = deconv(1024, 512)
+        self.deconv4 = deconv(1026, 256)
+        self.deconv3 = deconv(770, 128)
+        self.deconv2 = deconv(386, 64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(1026)
+        self.predict_flow4 = predict_flow(770)
+        self.predict_flow3 = predict_flow(386)
+        self.predict_flow2 = predict_flow(194)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=True)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensors of concatenated images.
+        Returns:
+            flow2 (tensor): Output flow tensors.
+        """
+        x1 = x[:, 0:3, :, :]
+        x2 = x[:, 3::, :, :]
+
+        out_conv1a = self.conv1(x1)
+        out_conv2a = self.conv2(out_conv1a)
+        out_conv3a = self.conv3(out_conv2a)
+
+        # FlownetC bottom input stream
+        out_conv1b = self.conv1(x2)
+
+        out_conv2b = self.conv2(out_conv1b)
+        out_conv3b = self.conv3(out_conv2b)
+
+        # Merge streams
+        if self.args.fp16:
+            out_corr = self.corr(tofp32()(out_conv3a),
+                                 tofp32()(out_conv3b))  # False
+            out_corr = tofp16()(out_corr)
+        else:
+            out_corr = self.corr(out_conv3a, out_conv3b)  # False
+        out_corr = self.corr_activation(out_corr)
+
+        # Redirect top input stream and concatenate
+        out_conv_redir = self.conv_redir(out_conv3a)
+
+        in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+
+        # Merged conv layers
+        out_conv3_1 = self.conv3_1(in_conv3_1)
+
+        out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+
+        flow5 = self.predict_flow5(concat5)
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+
+        flow4 = self.predict_flow4(concat4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1)
+
+        flow3 = self.predict_flow3(concat3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
+
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return flow2,
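
For reference, the 473 input channels of `conv3_1` follow from the correlation settings above, using the usual FlowNetC accounting:

```python
# 21 x 21 displacement grid from max_displacement=20 and stride2=2, plus conv_redir.
max_displacement, stride2 = 20, 2
grid = 2 * (max_displacement // stride2) + 1
corr_channels = grid ** 2   # 441 cost-volume channels
print(corr_channels + 32)   # 473, the in_channels of conv3_1
```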
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..edddd2446d906bfc0b93df47b6f18a45ac42bc79
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+from torch.nn import init
+import torch
+import torch.nn as nn
+from .submodules import conv, i_conv, predict_flow, deconv
+
+
+class FlowNetFusion(nn.Module):
+    r"""FlowNet2 Fusion module. Check out the FlowNet2 paper for more details
+    https://arxiv.org/abs/1612.01925
+
+    Args:
+        args (obj): Network initialization arguments
+        use_batch_norm (bool): Use batch norm or not. Default is true.
+    """
+    def __init__(self, args, use_batch_norm=True):
+        super(FlowNetFusion, self).__init__()
+
+        self.use_batch_norm = use_batch_norm
+        self.conv0 = conv(self.use_batch_norm, 11, 64)
+        self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2)
+        self.conv1_1 = conv(self.use_batch_norm, 64, 128)
+        self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2)
+        self.conv2_1 = conv(self.use_batch_norm, 128, 128)
+
+        self.deconv1 = deconv(128, 32)
+        self.deconv0 = deconv(162, 16)
+
+        self.inter_conv1 = i_conv(self.use_batch_norm, 162, 32)
+        self.inter_conv0 = i_conv(self.use_batch_norm, 82, 16)
+
+        self.predict_flow2 = predict_flow(128)
+        self.predict_flow1 = predict_flow(32)
+        self.predict_flow0 = predict_flow(16)
+
+        self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensors of concatenated images.
+        Returns:
+            flow2 (tensor): Output flow tensors.
+        """
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+        flow2 = self.predict_flow2(out_conv2)
+        flow2_up = self.upsampled_flow2_to_1(flow2)
+        out_deconv1 = self.deconv1(out_conv2)
+
+        concat1 = torch.cat((out_conv1, out_deconv1, flow2_up), 1)
+        out_interconv1 = self.inter_conv1(concat1)
+        flow1 = self.predict_flow1(out_interconv1)
+        flow1_up = self.upsampled_flow1_to_0(flow1)
+        out_deconv0 = self.deconv0(concat1)
+
+        concat0 = torch.cat((out_conv0, out_deconv0, flow1_up), 1)
+        out_interconv0 = self.inter_conv0(concat0)
+        flow0 = self.predict_flow0(out_interconv0)
+
+        return flow0
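
The 11 input channels of `conv0` correspond to the fusion inputs assembled in `FlowNet2.forward`: image 0, the two candidate flows, their magnitudes, and the two warping-error maps:

```python
# Channel accounting for the FlowNetFusion input (concat3 in FlowNet2.forward).
img0, flows, flow_norms, warp_errors = 3, 2 + 2, 1 + 1, 1 + 1
print(img0 + flows + flow_norms + warp_errors)  # 11, the in_channels of conv0
```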
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba8a76c5a5d66354e07aad2522a782961a50c24c
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+'''
+Portions of this code copyright 2017, Clement Pinard
+'''
+from torch.nn import init
+import torch
+import torch.nn as nn
+from .submodules import conv, predict_flow, deconv
+
+
+class FlowNetS(nn.Module):
+    r"""FlowNet2 S module. Check out the FlowNet2 paper for more details
+    https://arxiv.org/abs/1612.01925
+
+    Args:
+        args (obj): Network initialization arguments
+        input_channels (int): Number of input channels. Default is 12.
+        use_batch_norm (bool): Use batch norm or not. Default is true.
+    """
+    def __init__(self, args, input_channels=12, use_batch_norm=True):
+        super(FlowNetS, self).__init__()
+
+        self.use_batch_norm = use_batch_norm
+        self.conv1 = conv(
+            self.use_batch_norm,
+            input_channels,
+            64,
+            kernel_size=7,
+            stride=2)
+        self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2)
+        self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5,
+                          stride=2)
+        self.conv3_1 = conv(self.use_batch_norm, 256, 256)
+        self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+        self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+        self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+        self.deconv5 = deconv(1024, 512)
+        self.deconv4 = deconv(1026, 256)
+        self.deconv3 = deconv(770, 128)
+        self.deconv2 = deconv(386, 64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(1026)
+        self.predict_flow4 = predict_flow(770)
+        self.predict_flow3 = predict_flow(386)
+        self.predict_flow2 = predict_flow(194)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(
+            2, 2, 4, 2, 1, bias=False)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensors of concatenated images.
+        Returns:
+            flow2 (tensor): Output flow tensors.
+        """
+        out_conv1 = self.conv1(x)
+
+        out_conv2 = self.conv2(out_conv1)
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+        flow5 = self.predict_flow5(concat5)
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+        flow4 = self.predict_flow4(concat4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+
+        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+        flow3 = self.predict_flow3(concat3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return flow2,
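
The decoder widths (1026, 770, 386, 194) are simply the sums of the encoder skip, the deconv output, and the 2-channel upsampled flow at each scale; for reference:

```python
# Channel bookkeeping for the FlowNetS decoder concatenations.
skips = {'concat5': 512, 'concat4': 512, 'concat3': 256, 'concat2': 128}
deconvs = {'concat5': 512, 'concat4': 256, 'concat3': 128, 'concat2': 64}
for name in ('concat5', 'concat4', 'concat3', 'concat2'):
    print(name, skips[name] + deconvs[name] + 2)  # 1026, 770, 386, 194
```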
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f4340347252a9591d7540689abaae821d759060
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+import torch
+import torch.nn as nn
+from .submodules import conv, i_conv, predict_flow, deconv
+from torch.nn import init
+
+
+class FlowNetSD(nn.Module):
+    r"""FlowNet2 SD module. Check out the FlowNet2 paper for more details
+    https://arxiv.org/abs/1612.01925
+
+    Args:
+        args (obj): Network initialization arguments
+        use_batch_norm (bool): Use batch norm or not. Default is true.
+    """
+    def __init__(self, args, use_batch_norm=True):
+        super(FlowNetSD, self).__init__()
+
+        self.use_batch_norm = use_batch_norm
+        self.conv0 = conv(self.use_batch_norm, 6, 64)
+        self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2)
+        self.conv1_1 = conv(self.use_batch_norm, 64, 128)
+        self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2)
+        self.conv2_1 = conv(self.use_batch_norm, 128, 128)
+        self.conv3 = conv(self.use_batch_norm, 128, 256, stride=2)
+        self.conv3_1 = conv(self.use_batch_norm, 256, 256)
+        self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+        self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+        self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+        self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+        self.deconv5 = deconv(1024, 512)
+        self.deconv4 = deconv(1026, 256)
+        self.deconv3 = deconv(770, 128)
+        self.deconv2 = deconv(386, 64)
+
+        self.inter_conv5 = i_conv(self.use_batch_norm, 1026, 512)
+        self.inter_conv4 = i_conv(self.use_batch_norm, 770, 256)
+        self.inter_conv3 = i_conv(self.use_batch_norm, 386, 128)
+        self.inter_conv2 = i_conv(self.use_batch_norm, 194, 64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(512)
+        self.predict_flow4 = predict_flow(256)
+        self.predict_flow3 = predict_flow(128)
+        self.predict_flow2 = predict_flow(64)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+                                     align_corners=False)
+
+    def forward(self, x):
+        r"""
+
+        Args:
+            x (tensor): Input tensors of concatenated images.
+        Returns:
+            flow2 (tensor): Output flow tensors.
+        """
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6 = self.predict_flow6(out_conv6)
+        flow6_up = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+
+        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+        out_interconv5 = self.inter_conv5(concat5)
+        flow5 = self.predict_flow5(out_interconv5)
+
+        flow5_up = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+
+        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+        out_interconv4 = self.inter_conv4(concat4)
+        flow4 = self.predict_flow4(out_interconv4)
+        flow4_up = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+
+        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+        out_interconv3 = self.inter_conv3(concat3)
+        flow3 = self.predict_flow3(out_interconv3)
+        flow3_up = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+        out_interconv2 = self.inter_conv2(concat2)
+        flow2 = self.predict_flow2(out_interconv2)
+
+        if self.training:
+            return flow2, flow3, flow4, flow5, flow6
+        else:
+            return flow2,
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/submodules.py b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ab504401c1473bcc52ae4a1029afd74eed6d11
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py
@@ -0,0 +1,113 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1):
+    if use_batch_norm:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+                      stride=stride, padding=(kernel_size - 1) // 2,
+                      bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.LeakyReLU(0.1, inplace=True)
+        )
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+                      stride=stride, padding=(kernel_size - 1) // 2,
+                      bias=True),
+            nn.LeakyReLU(0.1, inplace=True))
+
+
+def i_conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1,
+           bias=True):
+    if use_batch_norm:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+                      stride=stride, padding=(kernel_size - 1) // 2,
+                      bias=bias),
+            nn.BatchNorm2d(out_planes),
+        )
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+                      stride=stride, padding=(kernel_size - 1) // 2,
+                      bias=bias),
+        )
+
+
+def predict_flow(in_planes):
+    return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1,
+                     bias=True)
+
+
+def deconv(in_planes, out_planes):
+    return nn.Sequential(
+        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2,
+                           padding=1, bias=True),
+        nn.LeakyReLU(0.1, inplace=True)
+    )
+
+
+class tofp16(nn.Module):
+    def __init__(self):
+        super(tofp16, self).__init__()
+
+    def forward(self, input):
+        return input.half()
+
+
+class tofp32(nn.Module):
+    def __init__(self):
+        super(tofp32, self).__init__()
+
+    def forward(self, input):
+        return input.float()
+
+
+def init_deconv_bilinear(weight):
+    # Fill a (transposed-)convolution weight tensor with a fixed bilinear
+    # upsampling kernel, identical for every (out, in) channel pair.
+    f_shape = weight.size()
+    height, width = f_shape[-2], f_shape[-1]
+    f = np.ceil(width / 2.0)
+    c = (2 * f - 1 - f % 2) / (2.0 * f)
+    bilinear = np.zeros([height, width])
+    for x in range(width):
+        for y in range(height):
+            value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+            bilinear[x, y] = value
+    weight.data.fill_(0.)
+    for i in range(f_shape[0]):
+        for j in range(f_shape[1]):
+            weight.data[i, j, :, :] = torch.from_numpy(bilinear)
+
+
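+# Returns a backward hook that stashes the incoming gradient in `grads[name]`,
+# e.g. tensor.register_hook(save_grad(grads, 'flow2')).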
+def save_grad(grads, name):
+    def hook(grad):
+        grads[name] = grad
+    return hook
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/__init__.py b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bffeba58a93de4379c8e9ed54af58b56baa13eb
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py
@@ -0,0 +1,219 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import matplotlib.pyplot as plt
+import os.path
+
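+# Magic number written at the start of every Middlebury .flo file.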
+TAG_CHAR = np.array([202021.25], np.float32)
+
+
+def readFlow(fn):
+    """ Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/
+    # reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+    # WARNING: this will work on little-endian architectures
+    # (eg Intel x86) only!
+    # print 'fn = %s'%(fn)
+    with open(fn, 'rb') as f:
+        magic = np.fromfile(f, np.float32, count=1)
+        if 202021.25 != magic:
+            print('Magic number incorrect. Invalid .flo file')
+            return None
+        else:
+            w = np.fromfile(f, np.int32, count=1)
+            h = np.fromfile(f, np.int32, count=1)
+            # print 'Reading %d x %d flo file\n' % (w, h)
+            data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
+            # Reshape data into 3D array (columns, rows, bands)
+            # The reshape here is for visualization, the original code is
+            # (w,h,2)
+            return np.resize(data, (int(h), int(w), 2))
+
+
+def writeFlow(filename, uv, v=None):
+    """ Write optical flow to file.
+
+    If v is None, uv is assumed to contain both u and v channels,
+    stacked in depth.
+    Original code by Deqing Sun, adapted from Daniel Scharstein.
+    """
+    nBands = 2
+
+    if v is None:
+        assert(uv.ndim == 3)
+        assert(uv.shape[2] == 2)
+        u = uv[:, :, 0]
+        v = uv[:, :, 1]
+    else:
+        u = uv
+
+    assert(u.shape == v.shape)
+    height, width = u.shape
+    f = open(filename, 'wb')
+    # write the header
+    f.write(TAG_CHAR)
+    np.array(width).astype(np.int32).tofile(f)
+    np.array(height).astype(np.int32).tofile(f)
+    # arrange into matrix form
+    tmp = np.zeros((height, width * nBands))
+    tmp[:, np.arange(width) * 2] = u
+    tmp[:, np.arange(width) * 2 + 1] = v
+    tmp.astype(np.float32).tofile(f)
+    f.close()
+
+
+# ref: https://github.com/sampepose/flownet2-tf/
+# blob/18f87081db44939414fc4a48834f9e0da3e69f4c/src/flowlib.py#L240
+def visulize_flow_file(flow_filename, save_dir=None):
+    flow_data = readFlow(flow_filename)
+    img = flow2img(flow_data)
+    # plt.imshow(img)
+    # plt.show()
+    if save_dir:
+        idx = flow_filename.rfind("/") + 1
+        plt.imsave(os.path.join(save_dir, "%s-vis.png" %
+                                flow_filename[idx:-4]), img)
+
+
+def flow2img(flow_data):
+    """
+    convert optical flow into color image
+    :param flow_data:
+    :return: color image
+    """
+    # print(flow_data.shape)
+    # print(type(flow_data))
+    u = flow_data[:, :, 0]
+    v = flow_data[:, :, 1]
+
+    UNKNOWN_FLOW_THRESHOLD = 1e7
+    pr1 = abs(u) > UNKNOWN_FLOW_THRESHOLD
+    pr2 = abs(v) > UNKNOWN_FLOW_THRESHOLD
+    idx_unknown = (pr1 | pr2)
+    u[idx_unknown] = v[idx_unknown] = 0
+
+    # get max value in each direction
+    maxu = -999.
+    maxv = -999.
+    minu = 999.
+    minv = 999.
+    maxu = max(maxu, np.max(u))
+    maxv = max(maxv, np.max(v))
+    minu = min(minu, np.min(u))
+    minv = min(minv, np.min(v))
+
+    rad = np.sqrt(u ** 2 + v ** 2)
+    maxrad = max(-1, np.max(rad))
+    # Normalize by the maximum magnitude; eps guards against an all-zero flow.
+    u = u / (maxrad + np.finfo(float).eps)
+    v = v / (maxrad + np.finfo(float).eps)
+
+    img = compute_color(u, v)
+
+    idx = np.repeat(idx_unknown[:, :, np.newaxis], 3, axis=2)
+    img[idx] = 0
+
+    return np.uint8(img)
+
+
+def compute_color(u, v):
+    """
+    compute optical flow color map
+    :param u: horizontal optical flow
+    :param v: vertical optical flow
+    :return:
+    """
+
+    height, width = u.shape
+    img = np.zeros((height, width, 3))
+
+    NAN_idx = np.isnan(u) | np.isnan(v)
+    u[NAN_idx] = v[NAN_idx] = 0
+
+    colorwheel = make_color_wheel()
+    ncols = np.size(colorwheel, 0)
+
+    rad = np.sqrt(u ** 2 + v ** 2)
+
+    a = np.arctan2(-v, -u) / np.pi
+
+    fk = (a + 1) / 2 * (ncols - 1) + 1
+
+    k0 = np.floor(fk).astype(int)
+
+    k1 = k0 + 1
+    k1[k1 == ncols + 1] = 1
+    f = fk - k0
+
+    for i in range(0, np.size(colorwheel, 1)):
+        tmp = colorwheel[:, i]
+        col0 = tmp[k0 - 1] / 255
+        col1 = tmp[k1 - 1] / 255
+        col = (1 - f) * col0 + f * col1
+
+        idx = rad <= 1
+        col[idx] = 1 - rad[idx] * (1 - col[idx])
+        notidx = np.logical_not(idx)
+
+        col[notidx] *= 0.75
+        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - NAN_idx)))
+
+    return img
+
+
+def make_color_wheel():
+    """
+    Generate color wheel according Middlebury color code
+    :return: Color wheel
+    """
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+
+    colorwheel = np.zeros([ncols, 3])
+
+    col = 0
+
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+    col += RY
+
+    # YG
+    colorwheel[col:col + YG, 0] = 255 - \
+        np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+    colorwheel[col:col + YG, 1] = 255
+    col += YG
+
+    # GC
+    colorwheel[col:col + GC, 1] = 255
+    colorwheel[col:col + GC,
+               2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+    col += GC
+
+    # CB
+    colorwheel[col:col + CB, 1] = 255 - \
+        np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+    colorwheel[col:col + CB, 2] = 255
+    col += CB
+
+    # BM
+    colorwheel[col:col + BM, 2] = 255
+    colorwheel[col:col + BM,
+               0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+    col += BM
+
+    # MR
+    colorwheel[col:col + MR, 2] = 255 - \
+        np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+    colorwheel[col:col + MR, 0] = 255
+
+    return colorwheel
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac1e3f83d6179afcb266e5923af5dd54fc3dd3fc
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py
@@ -0,0 +1,23 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+from os.path import splitext
+# scipy.misc.imread was removed in SciPy >= 1.2; imageio provides a
+# compatible reader for the image formats handled below.
+from imageio import imread
+from . import flow_utils
+
+
+def read_gen(file_name):
+    ext = splitext(file_name)[-1]
+    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+        im = imread(file_name)
+        if im.shape[2] > 3:
+            return im[:, :, :3]
+        else:
+            return im
+    elif ext == '.bin' or ext == '.raw':
+        return np.load(file_name)
+    elif ext == '.flo':
+        return flow_utils.readFlow(file_name).astype(np.float32)
+    return []
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b084c9c35b957888acea86987ab25073e43feac0
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py
@@ -0,0 +1,275 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def parse_flownetc(modules, weights, biases):
+    keys = [
+        'conv1',
+        'conv2',
+        'conv3',
+        'conv_redir',
+        'conv3_1',
+        'conv4',
+        'conv4_1',
+        'conv5',
+        'conv5_1',
+        'conv6',
+        'conv6_1',
+
+        'deconv5',
+        'deconv4',
+        'deconv3',
+        'deconv2',
+
+        'Convolution1',
+        'Convolution2',
+        'Convolution3',
+        'Convolution4',
+        'Convolution5',
+
+        'upsample_flow6to5',
+        'upsample_flow5to4',
+        'upsample_flow4to3',
+        'upsample_flow3to2',
+
+    ]
+    i = 0
+    for m in modules:
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            weight = weights[keys[i]].copy()
+            bias = biases[keys[i]].copy()
+            if keys[i] == 'conv1':
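+                # np.flip(..., axis=1) reverses the input-channel (RGB/BGR)
+                # order of the first convolution to match the channel ordering
+                # the pretrained Caffe weights were trained with.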
+                m.weight.data[:, :, :, :] = torch.from_numpy(
+                    np.flip(weight, axis=1).copy())
+                m.bias.data[:] = torch.from_numpy(bias)
+            else:
+                m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+                m.bias.data[:] = torch.from_numpy(bias)
+
+            i = i + 1
+    return
+
+
+def parse_flownets(modules, weights, biases, param_prefix='net2_'):
+    keys = [
+        'conv1',
+        'conv2',
+        'conv3',
+        'conv3_1',
+        'conv4',
+        'conv4_1',
+        'conv5',
+        'conv5_1',
+        'conv6',
+        'conv6_1',
+
+        'deconv5',
+        'deconv4',
+        'deconv3',
+        'deconv2',
+
+        'predict_conv6',
+        'predict_conv5',
+        'predict_conv4',
+        'predict_conv3',
+        'predict_conv2',
+
+        'upsample_flow6to5',
+        'upsample_flow5to4',
+        'upsample_flow4to3',
+        'upsample_flow3to2',
+    ]
+    for i, k in enumerate(keys):
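+        # Upsampling layers carry a doubled prefix because that is how they
+        # are named in the pretrained weight dictionary.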
+        if 'upsample' in k:
+            keys[i] = param_prefix + param_prefix + k
+        else:
+            keys[i] = param_prefix + k
+    i = 0
+    for m in modules:
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            weight = weights[keys[i]].copy()
+            bias = biases[keys[i]].copy()
+            if keys[i] == param_prefix + 'conv1':
+                m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 0:3, :, :], axis=1).copy())
+                m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 3:6, :, :], axis=1).copy())
+                m.weight.data[:, 6:9, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 6:9, :, :], axis=1).copy())
+                m.weight.data[:, 9::, :, :] = torch.from_numpy(
+                    weight[:, 9:, :, :].copy())
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            else:
+                m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            i = i + 1
+    return
+
+
+def parse_flownetsonly(modules, weights, biases, param_prefix=''):
+    keys = [
+        'conv1',
+        'conv2',
+        'conv3',
+        'conv3_1',
+        'conv4',
+        'conv4_1',
+        'conv5',
+        'conv5_1',
+        'conv6',
+        'conv6_1',
+
+        'deconv5',
+        'deconv4',
+        'deconv3',
+        'deconv2',
+
+        'Convolution1',
+        'Convolution2',
+        'Convolution3',
+        'Convolution4',
+        'Convolution5',
+
+        'upsample_flow6to5',
+        'upsample_flow5to4',
+        'upsample_flow4to3',
+        'upsample_flow3to2',
+    ]
+    for i, k in enumerate(keys):
+        if 'upsample' in k:
+            keys[i] = param_prefix + param_prefix + k
+        else:
+            keys[i] = param_prefix + k
+    i = 0
+    for m in modules:
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            weight = weights[keys[i]].copy()
+            bias = biases[keys[i]].copy()
+            if keys[i] == param_prefix + 'conv1':
+                # print ("%s :"%(keys[i]), m.weight.size(), m.bias.size(),
+                # tf_w[keys[i]].shape[::-1])
+                m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 0:3, :, :], axis=1).copy())
+                m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 3:6, :, :], axis=1).copy())
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            else:
+                m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            i = i + 1
+    return
+
+
+def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'):
+    keys = [
+        'conv0',
+        'conv1',
+        'conv1_1',
+        'conv2',
+        'conv2_1',
+        'conv3',
+        'conv3_1',
+        'conv4',
+        'conv4_1',
+        'conv5',
+        'conv5_1',
+        'conv6',
+        'conv6_1',
+
+        'deconv5',
+        'deconv4',
+        'deconv3',
+        'deconv2',
+
+        'interconv5',
+        'interconv4',
+        'interconv3',
+        'interconv2',
+
+        'Convolution1',
+        'Convolution2',
+        'Convolution3',
+        'Convolution4',
+        'Convolution5',
+
+        'upsample_flow6to5',
+        'upsample_flow5to4',
+        'upsample_flow4to3',
+        'upsample_flow3to2',
+    ]
+    for i, k in enumerate(keys):
+        keys[i] = param_prefix + k
+
+    i = 0
+    for m in modules:
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            weight = weights[keys[i]].copy()
+            bias = biases[keys[i]].copy()
+            if keys[i] == param_prefix + 'conv0':
+                m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 0:3, :, :], axis=1).copy())
+                m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 3:6, :, :], axis=1).copy())
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            else:
+                m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            i = i + 1
+
+    return
+
+
+def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'):
+    keys = [
+        'conv0',
+        'conv1',
+        'conv1_1',
+        'conv2',
+        'conv2_1',
+
+        'deconv1',
+        'deconv0',
+
+        'interconv1',
+        'interconv0',
+
+        '_Convolution5',
+        '_Convolution6',
+        '_Convolution7',
+
+        'upsample_flow2to1',
+        'upsample_flow1to0',
+    ]
+    for i, k in enumerate(keys):
+        keys[i] = param_prefix + k
+
+    i = 0
+    for m in modules:
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            weight = weights[keys[i]].copy()
+            bias = biases[keys[i]].copy()
+            if keys[i] == param_prefix + 'conv0':
+                m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+                    np.flip(weight[:, 0:3, :, :], axis=1).copy())
+                m.weight.data[:, 3::, :, :] = torch.from_numpy(
+                    weight[:, 3:, :, :].copy())
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            else:
+                m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+                if m.bias is not None:
+                    m.bias.data[:] = torch.from_numpy(bias)
+            i = i + 1
+
+    return
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/tools.py b/imaginaire/third_party/flow_net/flownet2/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6c9208117824c712aa54e3a3871273b851de63c
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/tools.py
@@ -0,0 +1,194 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import time
+import math
+import subprocess
+import shutil
+from os.path import join
+import numpy as np
+from inspect import isclass
+from pytz import timezone
+from datetime import datetime
+import inspect
+import torch
+
+
+def datestr():
+    pacific = timezone('US/Pacific')
+    now = datetime.now(pacific)
+    return '{}{:02}{:02}_{:02}{:02}'.format(
+        now.year, now.month, now.day, now.hour, now.minute)
+
+
+def module_to_dict(module, exclude=[]):
+    return dict([(x, getattr(module, x)) for x in dir(module)
+                 if isclass(getattr(module, x))
+                 and x not in exclude
+                 and getattr(module, x) not in exclude])
+
+
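+# Minimal timing context manager: prints the title on construction and lets
+# callers emit messages prefixed with the elapsed time via log().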
+class TimerBlock:
+    def __init__(self, title):
+        print(("{}".format(title)))
+
+    def __enter__(self):
+        # time.clock() was removed in Python 3.8; perf_counter() is the
+        # closest portable replacement.
+        self.start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end = time.perf_counter()
+        self.interval = self.end - self.start
+
+        if exc_type is not None:
+            self.log("Operation failed\n")
+        else:
+            self.log("Operation finished\n")
+
+    def log(self, string):
+        duration = time.perf_counter() - self.start
+        units = 's'
+        if duration > 60:
+            duration = duration / 60.
+            units = 'm'
+        print(("  [{:.3f}{}] {}".format(duration, units, string)))
+
+    def log2file(self, fid, string):
+        fid = open(fid, 'a')
+        fid.write("%s\n" % (string))
+        fid.close()
+
+
+def add_arguments_for_module(
+        parser,
+        module,
+        argument_for_class,
+        default,
+        skip_params=[],
+        parameter_defaults={}):
+    argument_group = parser.add_argument_group(argument_for_class.capitalize())
+
+    module_dict = module_to_dict(module)
+    argument_group.add_argument(
+        '--' + argument_for_class,
+        type=str,
+        default=default,
+        choices=list(
+            module_dict.keys()))
+
+    args, unknown_args = parser.parse_known_args()
+    class_obj = module_dict[vars(args)[argument_for_class]]
+
+    # inspect.getargspec() was removed in Python 3.11; getfullargspec() is a
+    # drop-in replacement for the fields used here (args and defaults).
+    argspec = inspect.getfullargspec(class_obj.__init__)
+
+    defaults = argspec.defaults[::-1] if argspec.defaults else None
+
+    args = argspec.args[::-1]
+    for i, arg in enumerate(args):
+        cmd_arg = '{}_{}'.format(argument_for_class, arg)
+        if arg not in skip_params + ['self', 'args']:
+            if arg in list(parameter_defaults.keys()):
+                argument_group.add_argument(
+                    '--{}'.format(cmd_arg),
+                    type=type(
+                        parameter_defaults[arg]),
+                    default=parameter_defaults[arg])
+            elif (defaults is not None and i < len(defaults)):
+                argument_group.add_argument(
+                    '--{}'.format(cmd_arg),
+                    type=type(
+                        defaults[i]),
+                    default=defaults[i])
+            else:
+                print(("[Warning]: non-default argument '{}' "
+                       "detected on class '{}'. This argument "
+                       "cannot be modified via the command line"
+                       .format(arg, module.__class__.__name__)))
+            # We don't have a good way of dealing with
+            # inferring the type of the argument
+            # TODO: try creating a custom action and using ast's infer type?
+            # else:
+            #     argument_group.add_argument('--{}'.format(
+            #     cmd_arg), required=True)
+
+
+def kwargs_from_args(args, argument_for_class):
+    argument_for_class = argument_for_class + '_'
+    return {key[len(argument_for_class):]: value for key, value in list(vars(
+        args).items()) if
+            argument_for_class in key and key != argument_for_class + 'class'}
+
+
+def format_dictionary_of_losses(labels, values):
+    try:
+        string = ', '.join([('{}: {:' +
+                             ('.3f' if value >= 0.001 else '.1e') +
+                             '}').format(name, value) for name, value in
+                            zip(labels, values)])
+    except (TypeError, ValueError) as e:
+        print((list(zip(labels, values))))
+        string = '[Log Error] ' + str(e)
+
+    return string
+
+
+class IteratorTimer():
+    def __init__(self, iterable):
+        self.iterable = iterable
+        self.iterator = self.iterable.__iter__()
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return len(self.iterable)
+
+    def __next__(self):
+        start = time.time()
+        n = next(self.iterator)
+        self.last_duration = (time.time() - start)
+        return n
+
+    next = __next__
+
+
+def gpumemusage():
+    # Summarize per-GPU memory usage from `nvidia-smi` as
+    # "util%--used/total " entries.
+    gpu_mem = subprocess.check_output(
+        "nvidia-smi | grep MiB | cut -f 3 -d '|'",
+        shell=True).decode('utf-8')
+    gpu_mem = gpu_mem.replace(' ', '').replace('\n', '').replace('i', '')
+    all_stat = [float(a) for a in gpu_mem.replace('/', '').split('MB')[:-1]]
+
+    gpu_mem = ''
+    for i in range(len(all_stat) // 2):
+        curr, tot = all_stat[2 * i], all_stat[2 * i + 1]
+        util = "%1.2f" % (100 * curr / tot) + '%'
+        cmem = str(int(math.ceil(curr / 1024.))) + 'GB'
+        gmem = str(int(math.ceil(tot / 1024.))) + 'GB'
+        gpu_mem += util + '--' + join(cmem, gmem) + ' '
+    return gpu_mem
+
+
+def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer):
+    if args.schedule_lr_frequency > 0:
+        for param_group in optimizer.param_groups:
+            if (global_iteration + 1) % args.schedule_lr_frequency == 0:
+                param_group['lr'] /= float(args.schedule_lr_fraction)
+                param_group['lr'] = float(
+                    np.maximum(param_group['lr'], 0.000001))
+
+
+def save_checkpoint(state, is_best, path, prefix,
+                    filename='checkpoint.pth.tar'):
+    prefix_save = os.path.join(path, prefix)
+    name = prefix_save + '_' + filename
+    torch.save(state, name)
+    if is_best:
+        shutil.copyfile(name, prefix_save + '_model_best.pth.tar')
diff --git a/imaginaire/third_party/resample2d/resample2d.py b/imaginaire/third_party/resample2d/resample2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdea3fa9941894090aa124adda1b62e1ea5e012
--- /dev/null
+++ b/imaginaire/third_party/resample2d/resample2d.py
@@ -0,0 +1,62 @@
+# flake8: noqa
+from torch.nn.modules.module import Module
+from torch.autograd import Function
+from torch.cuda.amp import autocast
+import resample2d_cuda
+
+
+class Resample2dFunction(Function):
+
+    @staticmethod
+    # def forward(ctx, input1, input2, kernel_size=1, bilinear=True):
+    def forward(ctx, input1, input2, kernel_size=1):
+        assert input1.is_contiguous()
+        assert input2.is_contiguous()
+
+        ctx.save_for_backward(input1, input2)
+        ctx.kernel_size = kernel_size
+        ctx.bilinear = True
+
+        _, d, _, _ = input1.size()
+        b, _, h, w = input2.size()
+        output = input1.new(b, d, h, w).zero_()
+
+        resample2d_cuda.forward(input1, input2, output, kernel_size)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        assert grad_output.is_contiguous()
+
+        input1, input2 = ctx.saved_tensors
+
+        grad_input1 = input1.new_zeros(input1.size())
+        grad_input2 = input1.new_zeros(input2.size())
+
+        # resample2d_cuda.backward(input1, input2, grad_output.data,
+        #                          grad_input1.data, grad_input2.data,
+        #                          ctx.kernel_size, ctx.bilinear)
+        resample2d_cuda.backward(input1, input2, grad_output.data,
+                                 grad_input1.data, grad_input2.data,
+                                 ctx.kernel_size)
+
+        return grad_input1, grad_input2, None, None
+
+
+class Resample2d(Module):
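+    """Warp ``input1`` with the per-pixel flow field ``input2`` using the
+    custom ``resample2d_cuda`` kernel (bilinear sampling)."""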
+
+    def __init__(self, kernel_size=1, bilinear=True):
+        super(Resample2d, self).__init__()
+        self.kernel_size = kernel_size
+        self.bilinear = bilinear
+
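+    # The CUDA kernel is float32-only, so autocast is disabled and both
+    # inputs are cast to float explicitly.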
+    @autocast(False)
+    def forward(self, input1, input2):
+        input1, input2 = input1.float(), input2.float()
+        input1_c = input1.contiguous()
+        # return Resample2dFunction.apply(
+        #     input1_c, input2, self.kernel_size, self.bilinear)
+        return Resample2dFunction.apply(
+            input1_c, input2, self.kernel_size)
\ No newline at end of file
diff --git a/imaginaire/third_party/resample2d/setup.py b/imaginaire/third_party/resample2d/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c14d9743bf514e01661bd1bc307a4e3f8986fe
--- /dev/null
+++ b/imaginaire/third_party/resample2d/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    if cuda_version >= '11.0':
+        nvcc_args.append('-gencode')
+        nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+    name='resample2d_cuda',
+    py_modules=['resample2d'],
+    ext_modules=[
+        CUDAExtension('resample2d_cuda', [
+            './src/resample2d_cuda.cc',
+            './src/resample2d_kernel.cu'
+        ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+                               'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/third_party/resample2d/src/resample2d_cuda.cc b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b330a06bc0f20fe82c275e9a784f7ed91faf7717
--- /dev/null
+++ b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc
@@ -0,0 +1,34 @@
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include "resample2d_kernel.cuh"
+
+int resample2d_cuda_forward(
+    at::Tensor& input1,
+    at::Tensor& input2, 
+    at::Tensor& output,
+    int kernel_size/*, bool bilinear*/) {
+      resample2d_kernel_forward(input1, input2, output, kernel_size/*,
+      bilinear*/);
+    return 1;
+}
+
+int resample2d_cuda_backward(
+    at::Tensor& input1, 
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1, 
+    at::Tensor& gradInput2, 
+    int kernel_size/*, bool bilinear*/) {
+        resample2d_kernel_backward(input1, input2, gradOutput, gradInput1,
+        gradInput2, kernel_size/*, bilinear*/);
+    return 1;
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
+  m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
+}
+
diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cu b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..654ca8e417b6d22ff623d72d19e24997a2db284c
--- /dev/null
+++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu
@@ -0,0 +1,328 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CUDA_NUM_THREADS 512 
+#define THREADS_PER_BLOCK 64 
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
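+// Forward pass: each output element (b, c, y, x) bilinearly samples input1 at
+// (x + dx, y + dy), where (dx, dy) is the flow stored in input2 at (b, :, y, x).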
+template <typename scalar_t>
+__global__ void kernel_resample2d_update_output(const int n, 
+                                               const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+                                               const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 
+                                               scalar_t* __restrict__ output,
+                                               const long4 output_size, const
+                                               long4 output_stride, int
+                                               kernel_size/*, bool bilinear*/) {
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    bool bilinear = true;
+    if (index >= n) {
+        return;
+    }
+
+    scalar_t val = 0.0f;
+
+    int dim_b = DIM0(output_size);
+    int dim_c = DIM1(output_size);
+    int dim_h = DIM2(output_size);
+    int dim_w = DIM3(output_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+    scalar_t alpha = xf - floor(xf); // alpha
+    scalar_t beta = yf - floor(yf); // beta
+
+    if (bilinear) {
+        int xL = max(min( int (floor(xf)),    dim_w-1), 0);
+        int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
+        int yT = max(min( int (floor(yf)),    dim_h-1), 0);
+        int yB = max(min( int (floor(yf)+1),  dim_h-1), 0);
+
+        for (int fy = 0; fy < kernel_size; fy += 1) {
+            for (int fx = 0; fx < kernel_size; fx += 1) {
+                val += static_cast<float>((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx));
+                val += static_cast<float>((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx));
+                val += static_cast<float>((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx));
+                val += static_cast<float>((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx));
+            }
+        }
+
+        output[index] = val;
+    }
+    else {
+        int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0);
+        int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0);
+
+        output[index] = static_cast<float> ( DIM3_INDEX(input1, b, c, yN, xN) );
+    }
+
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input1(
+    const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4
+    gradInput_stride, int kernel_size/*, bool bilinear*/) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    bool bilinear = true;
+    if (index >= n) {
+        return;
+    }
+
+    int dim_b = DIM0(gradOutput_size);
+    int dim_c = DIM1(gradOutput_size);
+    int dim_h = DIM2(gradOutput_size);
+    int dim_w = DIM3(gradOutput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+    scalar_t alpha = xf - int(xf); // alpha
+    scalar_t beta = yf - int(yf); // beta
+
+    int idim_h = DIM2(input1_size);
+    int idim_w = DIM3(input1_size);
+
+    int xL = max(min( int (floor(xf)),    idim_w-1), 0);
+    int xR = max(min( int (floor(xf)+1), idim_w -1), 0);
+    int yT = max(min( int (floor(yf)),    idim_h-1), 0);
+    int yB = max(min( int (floor(yf)+1),  idim_h-1), 0);
+
+    for (int fy = 0; fy < kernel_size; fy += 1) {
+        for (int fx = 0; fx < kernel_size; fx += 1) {
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)),   (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)),   (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)),     (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+        }
+    }
+
+}
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input2(
+    const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4
+    gradInput_stride, int kernel_size/*, bool bilinear*/) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    bool bilinear = true;
+    if (index >= n) {
+        return;
+    }
+
+    scalar_t output = 0.0;
+    int kernel_rad = (kernel_size - 1)/2;
+
+    int dim_b = DIM0(gradInput_size);
+    int dim_c = DIM1(gradInput_size);
+    int dim_h = DIM2(gradInput_size);
+    int dim_w = DIM3(gradInput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    int odim_c = DIM1(gradOutput_size);
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+
+    int xL = max(min( int (floor(xf)),    dim_w-1), 0);
+    int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
+    int yT = max(min( int (floor(yf)),    dim_h-1), 0);
+    int yB = max(min( int (floor(yf)+1),  dim_h-1), 0);
+    
+    if (c % 2) {
+        float gamma = 1 - (xf - floor(xf)); // alpha
+        for (int i = 0; i <= 2*kernel_rad; ++i) {
+            for (int j = 0; j <= 2*kernel_rad; ++j) {
+                for (int ch = 0; ch < odim_c; ++ch) {
+                    output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+                    output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+                    output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+                    output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+                }
+            }
+        }
+    }
+    else {
+        float gamma = 1 - (yf - floor(yf)); // alpha
+        for (int i = 0; i <= 2*kernel_rad; ++i) {
+            for (int j = 0; j <= 2*kernel_rad; ++j) {
+                for (int ch = 0; ch < odim_c; ++ch) {
+                    output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+                    output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+                    output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+                    output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+                }
+            }
+        }
+
+    }
+
+    gradInput[index] = output;
+
+}
+
+void resample2d_kernel_forward(
+    at::Tensor& input1, 
+    at::Tensor& input2,
+    at::Tensor& output, 
+    int kernel_size/*,
+    bool bilinear*/) {
+
+    int n = output.numel();
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
+    const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF
+//    AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] {
+
+        kernel_resample2d_update_output<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n,
+            input1.data<float>(),
+            input1_size,
+            input1_stride, 
+            input2.data<float>(),
+            input2_size,
+            input2_stride,
+            output.data<float>(),
+            output_size,
+            output_stride,
+            kernel_size/*,
+            bilinear*/);
+
+//    }));
+
+        // TODO: ATen-equivalent check
+
+       //    THCudaCheck(cudaGetLastError());
+
+}
+
+void resample2d_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    at::Tensor& gradInput2,
+    int kernel_size/*,
+    bool bilinear*/) {
+
+    int n = gradOutput.numel();
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
+    const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
+
+    const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+    const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+    const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+    const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+//    AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] {
+
+        kernel_resample2d_backward_input1<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n, 
+            input1.data<float>(), 
+            input1_size,
+            input1_stride,
+            input2.data<float>(),
+            input2_size, 
+            input2_stride,
+            gradOutput.data<float>(),
+            gradOutput_size,
+            gradOutput_stride,
+            gradInput1.data<float>(),
+            gradInput1_size,
+            gradInput1_stride, 
+            kernel_size/*,
+            bilinear*/
+        );
+
+//    }));
+
+    const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3));
+    const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3));
+
+    n = gradInput2.numel();
+
+//    AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] {
+
+
+        kernel_resample2d_backward_input2<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n, 
+            input1.data<float>(), 
+            input1_size, 
+            input1_stride,
+            input2.data<float>(), 
+            input2_size,
+            input2_stride,
+            gradOutput.data<float>(),
+            gradOutput_size,
+            gradOutput_stride,
+            gradInput2.data<float>(),
+            gradInput2_size,
+            gradInput2_stride,
+            kernel_size/*,
+            bilinear*/
+       );
+
+//    }));
+
+    // TODO: Use the ATen equivalent to get last error
+
+    //    THCudaCheck(cudaGetLastError());
+
+}
diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3a815269a562e762cd7bd0c73af21d468d4eb2fd
--- /dev/null
+++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void resample2d_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& output,
+    int kernel_size/*,
+    bool bilinear*/);
+
+void resample2d_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1, 
+    at::Tensor& gradInput2, 
+    int kernel_size/*,
+    bool bilinear*/);
\ No newline at end of file
diff --git a/imaginaire/third_party/upfirdn2d/__init__.py b/imaginaire/third_party/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c92bf9c45932a8578f64ef83dc0b067ebd27ca0
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import BlurUpsample, BlurDownsample, Blur
+
+__all__ = ['BlurUpsample', 'BlurDownsample', 'Blur']
diff --git a/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eaa606ee083d81fe1526c72e9939d485d771c19a
Binary files /dev/null and b/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc b/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf77a6379ef867beeb0dd391f94950ab0d4b6c91
Binary files /dev/null and b/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc differ
diff --git a/imaginaire/third_party/upfirdn2d/setup.py b/imaginaire/third_party/upfirdn2d/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69207b2daaeaebaa81e5dfe3c1001656fe4248a
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    if cuda_version >= '11.0':
+        nvcc_args.append('-gencode')
+        nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+    name='upfirdn2d_cuda',
+    py_modules=['upfirdn2d'],
+    ext_modules=[
+        CUDAExtension('upfirdn2d_cuda', [
+            './src/upfirdn2d_cuda.cc',
+            './src/upfirdn2d_cuda_kernel.cu'
+        ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+                               'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..65df7a9ad78e4f6f7560feed79048983f60e8add
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d_cuda.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d_cuda(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+    // Initialize CUDA kernel parameters.
+    upfirdn2d_kernel_params p;
+    p.x             = x.data_ptr();
+    p.f             = f.data_ptr<float>();
+    p.y             = y.data_ptr();
+    p.up            = make_int2(upx, upy);
+    p.down          = make_int2(downx, downy);
+    p.pad0          = make_int2(padx0, pady0);
+    p.flip          = (flip) ? 1 : 0;
+    p.gain          = gain;
+    p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+    p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+    p.filterSize    = make_int2((int)f.size(1), (int)f.size(0));
+    p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0));
+    p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+    p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+    p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+    p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+    // Choose CUDA kernel.
+    upfirdn2d_kernel_spec spec;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&]
+    {
+        spec = choose_upfirdn2d_kernel<scalar_t>(p);
+    });
+
+    // Set looping options.
+    p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1;
+    p.loopMinor     = spec.loopMinor;
+    p.loopX         = spec.loopX;
+    p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1;
+    p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+    // Compute grid size.
+    dim3 blockSize, gridSize;
+    if (spec.tileOutW < 0) // large
+    {
+        blockSize = dim3(4, 32, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+            p.launchMajor);
+    }
+    else // small
+    {
+        blockSize = dim3(256, 1, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+            p.launchMajor);
+    }
+
+    // Launch CUDA kernel.
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("upfirdn2d_cuda", &upfirdn2d_cuda);
+}
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+    const void*     x;
+    const float*    f;
+    void*           y;
+
+    int2            up;
+    int2            down;
+    int2            pad0;
+    int             flip;
+    float           gain;
+
+    int4            inSize;         // [width, height, channel, batch]
+    int4            inStride;
+    int2            filterSize;     // [width, height]
+    int2            filterStride;
+    int4            outSize;        // [width, height, channel, batch]
+    int4            outStride;
+    int             sizeMinor;
+    int             sizeMajor;
+
+    int             loopMinor;
+    int             loopMajor;
+    int             loopX;
+    int             launchMinor;
+    int             launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct upfirdn2d_kernel_spec
+{
+    void*   kernel;
+    int     tileOutW;
+    int     tileOutH;
+    int     loopMinor;
+    int     loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d7f8938f7ac1220d934fe6a357de543a452445e4
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto.  Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "upfirdn2d_cuda.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>     { typedef double scalar_t; };
+template <> struct InternalType<float>      { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+    int t = 1 - a / b;
+    return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+
+    // Calculate thread index.
+    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+    int outY = minorBase / p.launchMinor;
+    minorBase -= outY * p.launchMinor;
+    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+    int majorBase = blockIdx.z * p.loopMajor;
+    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+        return;
+
+    // Setup Y receptive field.
+    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+    if (p.flip)
+        filterY = p.filterSize.y - 1 - filterY;
+
+    // Loop over major, minor, and X.
+    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+    {
+        int nc = major * p.sizeMinor + minor;
+        int n = nc / p.inSize.z;
+        int c = nc - n * p.inSize.z;
+        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+        {
+            // Setup X receptive field.
+            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+            if (p.flip)
+                filterX = p.filterSize.x - 1 - filterX;
+
+            // Initialize pointers.
+            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+            // Inner loop.
+            scalar_t v = 0;
+            for (int y = 0; y < h; y++)
+            {
+                for (int x = 0; x < w; x++)
+                {
+                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
+                    xp += p.inStride.x;
+                    fp += filterStepX;
+                }
+                xp += p.inStride.y - w * p.inStride.x;
+                fp += filterStepY - w * filterStepX;
+            }
+
+            // Store result.
+            v *= p.gain;
+            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+        }
+    }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
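+// Tile-based variant for small filters: each block first stages the
+// (flipped) filter and the input tile it needs in shared memory, then every
+// thread accumulates its output pixels entirely from shared memory. The
+// filter size and up/down factors are template parameters, so the inner
+// loops are fully unrolled at compile time.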
+template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+    __shared__ volatile scalar_t sf[filterH][filterW];
+    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+    // Calculate tile index.
+    int minorBase = blockIdx.x;
+    int tileOutY = minorBase / p.launchMinor;
+    minorBase -= tileOutY * p.launchMinor;
+    minorBase *= loopMinor;
+    tileOutY *= tileOutH;
+    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+    int majorBase = blockIdx.z * p.loopMajor;
+    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+        return;
+
+    // Load filter (flipped).
+    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+    {
+        int fy = tapIdx / filterW;
+        int fx = tapIdx - fy * filterW;
+        scalar_t v = 0;
+        if (fx < p.filterSize.x & fy < p.filterSize.y)
+        {
+            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+        }
+        sf[fy][fx] = v;
+    }
+
+    // Loop over major and X.
+    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+    {
+        int baseNC = major * p.sizeMinor + minorBase;
+        int n = baseNC / p.inSize.z;
+        int baseC = baseNC - n * p.inSize.z;
+        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+        {
+            // Load input pixels.
+            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+            int tileInX = floor_div(tileMidX, upx);
+            int tileInY = floor_div(tileMidY, upy);
+            __syncthreads();
+            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+            {
+                int relC = inIdx;
+                int relInX = relC / loopMinor;
+                int relInY = relInX / tileInW;
+                relC -= relInX * loopMinor;
+                relInX -= relInY * tileInW;
+                int c = baseC + relC;
+                int inX = tileInX + relInX;
+                int inY = tileInY + relInY;
+                scalar_t v = 0;
+                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+                sx[relInY][relInX][relC] = v;
+            }
+
+            // Loop over output pixels.
+            __syncthreads();
+            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+            {
+                int relC = outIdx;
+                int relOutX = relC / loopMinor;
+                int relOutY = relOutX / tileOutW;
+                relC -= relOutX * loopMinor;
+                relOutX -= relOutY * tileOutW;
+                int c = baseC + relC;
+                int outX = tileOutX + relOutX;
+                int outY = tileOutY + relOutY;
+
+                // Setup receptive field.
+                int midX = tileMidX + relOutX * downx;
+                int midY = tileMidY + relOutY * downy;
+                int inX = floor_div(midX, upx);
+                int inY = floor_div(midY, upy);
+                int relInX = inX - tileInX;
+                int relInY = inY - tileInY;
+                int filterX = (inX + 1) * upx - midX - 1; // flipped
+                int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+                // Inner loop.
+                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+                {
+                    scalar_t v = 0;
+                    #pragma unroll
+                    for (int y = 0; y < filterH / upy; y++)
+                        #pragma unroll
+                        for (int x = 0; x < filterW / upx; x++)
+                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+                    v *= p.gain;
+                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+                }
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+    if (s == 1)           spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  64,16,1>, 64,16,1, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  64,16,1>, 64,16,1, 1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  64,16,1>, 64,16,1, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  64,16,1>, 64,16,1, 1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  64,16,1>, 64,16,1, 1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  16,16,8>,  16,16,8,  1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>,  16,16,8,  1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>,  16,16,8,  1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  16,16,8>,  16,16,8,  1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  16,16,8>,  16,16,8,  1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,  32,8,1>, 32,8,1, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,  32,8,1>, 32,8,1, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,  32,8,1>, 32,8,1, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,  32,8,1>, 32,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,  8,8,8>, 8,8,8, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,  8,8,8>, 8,8,8, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,  8,8,8>, 8,8,8, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,  8,8,8>, 8,8,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};
+    }
+    return spec;
+}
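+
+// Illustrative host-side usage (a sketch, not part of this file): the C++
+// wrapper that binds this op to PyTorch is expected to fill in
+// upfirdn2d_kernel_params, pick a spec for the tensor's dtype, and launch it
+// via the CUDA runtime, along the lines of:
+//
+//     upfirdn2d_kernel_spec spec = choose_upfirdn2d_kernel<float>(p);
+//     // ...derive p.loop*/p.launch* and the grid/block sizes from spec...
+//     void* args[] = {&p};
+//     cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, stream);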
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/upfirdn2d/upfirdn2d.py b/imaginaire/third_party/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..8548efe56653d6d8083f68d6e6617ba84b398d1e
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,471 @@
+# flake8: noqa
+# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient resampling of 2D images."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import upfirdn2d_cuda
+
+
+def _parse_scaling(scaling):
+    if isinstance(scaling, int):
+        scaling = [scaling, scaling]
+    assert isinstance(scaling, (list, tuple))
+    assert all(isinstance(x, int) for x in scaling)
+    sx, sy = scaling
+    assert sx >= 1 and sy >= 1
+    return sx, sy
+
+
+def _parse_padding(padding):
+    if isinstance(padding, int):
+        padding = [padding, padding]
+    assert isinstance(padding, (list, tuple))
+    assert all(isinstance(x, int) for x in padding)
+    if len(padding) == 2:
+        padx, pady = padding
+        padding = [padx, padx, pady, pady]
+    padx0, padx1, pady0, pady1 = padding
+    return padx0, padx1, pady0, pady1
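+
+# For illustration: _parse_padding(1) -> (1, 1, 1, 1) and
+# _parse_padding([1, 2]) -> (1, 1, 2, 2), i.e. a two-element spec is read as
+# (pad_x, pad_y) applied symmetrically to both sides of that axis.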
+
+
+def _get_filter_size(f):
+    if f is None:
+        return 1, 1
+    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+    fw = f.shape[-1]
+    fh = f.shape[0]
+    assert fw >= 1 and fh >= 1
+    return fw, fh
+
+
+class BlurUpsample(nn.Module):
+    def __init__(self,
+                 kernel=(1, 3, 3, 1),
+                 factor=2,
+                 padding_mode='zeros'):
+        super().__init__()
+        p = len(kernel)
+        px0 = (p + factor - 1) // 2
+        px1 = (p - factor) // 2
+        py0 = (p + factor - 1) // 2
+        py1 = (p - factor) // 2
+
+        self.pad = [px0, px1, py0, py1]
+        self.factor = factor
+        self.register_buffer('kernel', setup_filter(kernel))
+        self.kernel_1d = kernel
+        self.padding_mode = padding_mode
+
+    def forward(self, x):
+        if self.padding_mode != 'zeros':
+            x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode)
+            out = upfirdn2d(
+                x, self.kernel, up=self.factor, gain=self.factor ** 2)
+        else:
+            out = upfirdn2d(
+                x, self.kernel, up=self.factor, padding=self.pad,
+                gain=self.factor ** 2)
+        return out
+
+    def extra_repr(self):
+        s = 'kernel={kernel_1d}, ' \
+            'padding_mode={padding_mode}, pad={pad}'
+        return s.format(**self.__dict__)
+
+
+class BlurDownsample(nn.Module):
+    def __init__(self, kernel=(1, 3, 3, 1), factor=2, padding_mode='zeros'):
+        super().__init__()
+        p = len(kernel)
+        px0 = (p - factor + 1) // 2
+        px1 = (p - factor) // 2
+        py0 = (p - factor + 1) // 2
+        py1 = (p - factor) // 2
+
+        self.pad = [px0, px1, py0, py1]
+        self.factor = factor
+        self.register_buffer('kernel', setup_filter(kernel))
+        self.kernel_1d = kernel
+        self.padding_mode = padding_mode
+
+    def forward(self, x):
+        if self.padding_mode != 'zeros':
+            x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode)
+            out = upfirdn2d(x, self.kernel, down=self.factor)
+        else:
+            out = upfirdn2d(x, self.kernel, down=self.factor, padding=self.pad)
+        return out
+
+    def extra_repr(self):
+        s = 'kernel={kernel_1d}, ' \
+            'padding_mode={padding_mode}, pad={pad}'
+        return s.format(**self.__dict__)
+
+
+class Blur(nn.Module):
+    def __init__(self,
+                 kernel=(1, 3, 3, 1),
+                 pad=0,
+                 padding_mode='zeros'):
+        super().__init__()
+        self.register_buffer('kernel', setup_filter(kernel))
+        self.kernel_1d = kernel
+        self.padding_mode = padding_mode
+        self.pad = pad
+
+    def forward(self, x):
+        if self.padding_mode != 'zeros':
+            x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode)
+            out = upfirdn2d(x, self.kernel)
+        else:
+            out = upfirdn2d(x, self.kernel, padding=self.pad)
+        return out
+
+    def extra_repr(self):
+        s = 'kernel={kernel_1d}, ' \
+            'padding_mode={padding_mode}, pad={pad}'
+        return s.format(**self.__dict__)
+
+
+# ----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True,
+                 flip_filter=False, gain=1, separable=None):
+    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+
+    Args:
+        f:           Torch tensor, numpy array, or python list of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable),
+                     `[]` (impulse), or
+                     `None` (identity).
+        device:      Result device (default: cpu).
+        normalize:   Normalize the filter so that it retains the magnitude
+                     for constant input signal (DC)? (default: True).
+        flip_filter: Flip the filter? (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        separable:   Return a separable filter? (default: select automatically).
+
+    Returns:
+        Float32 tensor of the shape
+        `[filter_height, filter_width]` (non-separable) or
+        `[filter_taps]` (separable).
+    """
+    # Validate.
+    if f is None:
+        f = 1
+    f = torch.as_tensor(f, dtype=torch.float32)
+    assert f.ndim in [0, 1, 2]
+    assert f.numel() > 0
+    if f.ndim == 0:
+        f = f[np.newaxis]
+
+    # Separable?
+    if separable is None:
+        separable = (f.ndim == 1 and f.numel() >= 8)
+    if f.ndim == 1 and not separable:
+        f = f.ger(f)
+    assert f.ndim == (1 if separable else 2)
+
+    # Apply normalize, flip, gain, and device.
+    if normalize:
+        f /= f.sum()
+    if flip_filter:
+        f = f.flip(list(range(f.ndim)))
+    f = f * (gain ** (f.ndim / 2))
+    f = f.to(device=device)
+    return f
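+
+# Example (illustrative): setup_filter([1, 3, 3, 1]) returns the 4x4 outer
+# product of the taps, normalized so the coefficients sum to 1; a separable
+# (1D) filter is only chosen automatically for filters with >= 8 taps.
+#
+#     f = setup_filter([1, 3, 3, 1])   # shape [4, 4], f.sum() == 1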
+
+
+# ----------------------------------------------------------------------------
+
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Pad, upsample, filter, and downsample a batch of 2D images.
+
+    Performs the following sequence of operations for each channel:
+
+    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+    2. Pad the image with the specified number of zeros on each side (`padding`).
+       Negative padding corresponds to cropping the image.
+
+    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
+       so that the footprint of all output pixels lies within the input image.
+
+    4. Downsample the image by keeping every Nth pixel (`down`).
+
+    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+    The fused op is considerably more efficient than performing the same calculation
+    using standard PyTorch ops. It supports gradients of arbitrary order.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        up:          Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 1).
+        down:        Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 1).
+        padding:     Padding with respect to the upsampled image. Can be a single number
+                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda':
+        return _upfirdn2d_cuda(up=up, down=down, padding=padding,
+                               flip_filter=flip_filter, gain=gain).apply(x, f)
+    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
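+
+# Example (illustrative): 2x upsampling by zero insertion followed by a
+# 4-tap binomial blur; gain=4 compensates for the inserted zeros so the DC
+# level is preserved.
+#
+#     x = torch.randn(1, 3, 64, 64, device='cuda')
+#     f = setup_filter([1, 3, 3, 1])
+#     y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)  # [1, 3, 128, 128]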
+
+
+# ----------------------------------------------------------------------------
+
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+    """
+    # Validate arguments.
+    assert isinstance(x, torch.Tensor) and x.ndim == 4
+    if f is None:
+        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+    assert f.dtype == torch.float32 and not f.requires_grad
+    batch_size, num_channels, in_height, in_width = x.shape
+    upx, upy = _parse_scaling(up)
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+    # Upsample by inserting zeros.
+    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+    # Pad or crop.
+    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0),
+                                    max(pady1, 0)])
+    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0),
+        max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
+
+    # Setup filter.
+    f = f * (gain ** (f.ndim / 2))
+    f = f.to(x.dtype)
+    if not flip_filter:
+        f = f.flip(list(range(f.ndim)))
+
+    # Convolve with the filter.
+    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+    if f.ndim == 4:
+        x = F.conv2d(input=x, weight=f, groups=num_channels)
+    else:
+        x = F.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+        x = F.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+    # Downsample by throwing away pixels.
+    x = x[:, :, ::downy, ::downx]
+    return x
+
+
+# ----------------------------------------------------------------------------
+
+_upfirdn2d_cuda_cache = dict()
+
+
+def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
+    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
+    """
+    # Parse arguments.
+    upx, upy = _parse_scaling(up)
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+    # Lookup from cache.
+    key = (
+        upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+    if key in _upfirdn2d_cuda_cache:
+        return _upfirdn2d_cuda_cache[key]
+
+    # Forward op.
+    class Upfirdn2dCuda(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, f):  # pylint: disable=arguments-differ
+            assert isinstance(x, torch.Tensor) and x.ndim == 4
+            if f is None:
+                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+            y = x
+            if f.ndim == 2:
+                y = upfirdn2d_cuda.upfirdn2d_cuda(y, f, upx, upy, downx, downy, padx0,
+                                                  padx1, pady0, pady1, flip_filter, gain)
+            else:
+                y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(0), upx, 1, downx, 1,
+                                                  padx0, padx1, 0, 0, flip_filter,
+                                                  np.sqrt(gain))
+                y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0,
+                                                  pady0, pady1, flip_filter, np.sqrt(gain))
+            ctx.save_for_backward(f)
+            ctx.x_shape = x.shape
+            return y
+
+        @staticmethod
+        def backward(ctx, dy):  # pylint: disable=arguments-differ
+            f, = ctx.saved_tensors
+            _, _, ih, iw = ctx.x_shape
+            _, _, oh, ow = dy.shape
+            fw, fh = _get_filter_size(f)
+            p = [
+                fw - padx0 - 1,
+                iw * upx - ow * downx + padx0 - upx + 1,
+                fh - pady0 - 1,
+                ih * upy - oh * downy + pady0 - upy + 1,
+            ]
+            dx = None
+            df = None
+
+            if ctx.needs_input_grad[0]:
+                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+
+            assert not ctx.needs_input_grad[1]
+            return dx, df
+
+    # Add to cache.
+    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
+    return Upfirdn2dCuda
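+
+# Note on gradients: Upfirdn2dCuda.backward reuses the same op with the
+# up/down factors swapped, the padding adjusted to the transposed geometry,
+# and the filter flip toggled, which is what makes gradients of arbitrary
+# order possible. Gradients w.r.t. the filter are not supported (see the
+# assert in backward()).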
+
+
+# ----------------------------------------------------------------------------
+
+def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Filter a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape matches the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        padding:     Padding with respect to the output. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + fw // 2,
+        padx1 + (fw - 1) // 2,
+        pady0 + fh // 2,
+        pady1 + (fh - 1) // 2,
+    ]
+    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
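+
+# For reference: with a 4x4 filter and padding=0 the effective padding is
+# [2, 1, 2, 1], so the output height and width match the input.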
+
+
+# ----------------------------------------------------------------------------
+
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Upsample a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape is a multiple of the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        up:          Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the output. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    upx, upy = _parse_scaling(up)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw + upx - 1) // 2,
+        padx1 + (fw - upx) // 2,
+        pady0 + (fh + upy - 1) // 2,
+        pady1 + (fh - upy) // 2,
+    ]
+    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl)
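+
+# For reference: with up=2 and a 4-tap filter per axis the effective padding
+# is [2, 1, 2, 1] and the gain is scaled by upx * upy = 4, so a [N, C, H, W]
+# input yields a [N, C, 2H, 2W] output with its DC level preserved.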
+
+
+# ----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1,
+                 impl='cuda'):
+    r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape is a fraction of the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        down:        Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the input. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw - downx + 1) // 2,
+        padx1 + (fw - downx) // 2,
+        pady0 + (fh - downy + 1) // 2,
+        pady1 + (fh - downy) // 2,
+    ]
+    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
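+
+# For reference: with down=2 and a 4-tap filter per axis the effective
+# padding is [1, 1, 1, 1], so a [N, C, H, W] input (H, W even) yields a
+# [N, C, H/2, W/2] output.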
+
+# ----------------------------------------------------------------------------
diff --git a/imaginaire/trainers/__init__.py b/imaginaire/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/trainers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/trainers/base.py b/imaginaire/trainers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..461aed4b7f092ac6a376d1db7b28ef0fd646901f
--- /dev/null
+++ b/imaginaire/trainers/base.py
@@ -0,0 +1,982 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import json
+import os
+import time
+
+import torch
+import torchvision
+import wandb
+from torch.cuda.amp import GradScaler, autocast
+from tqdm import tqdm
+
+from imaginaire.utils.distributed import is_master, master_only
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.io import save_pilimage_in_jpeg
+from imaginaire.utils.meters import Meter
+from imaginaire.utils.misc import to_cuda, to_device, requires_grad, to_channels_last
+from imaginaire.utils.model_average import (calibrate_batch_norm_momentum,
+                                            reset_batch_norm)
+from imaginaire.utils.visualization import tensor2pilimage
+
+
+class BaseTrainer(object):
+    r"""Base trainer. We expect that all trainers inherit this class.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self,
+                 cfg,
+                 net_G,
+                 net_D,
+                 opt_G,
+                 opt_D,
+                 sch_G,
+                 sch_D,
+                 train_data_loader,
+                 val_data_loader):
+        super(BaseTrainer, self).__init__()
+        print('Setup trainer.')
+
+        # Initialize models and data loaders.
+        self.cfg = cfg
+        self.net_G = net_G
+        if cfg.trainer.model_average_config.enabled:
+            # Two wrappers (DDP + model average).
+            self.net_G_module = self.net_G.module.module
+        else:
+            # One wrapper (DDP).
+            self.net_G_module = self.net_G.module
+        self.val_data_loader = val_data_loader
+        self.is_inference = train_data_loader is None
+        self.net_D = net_D
+        self.opt_G = opt_G
+        self.opt_D = opt_D
+        self.sch_G = sch_G
+        self.sch_D = sch_D
+        self.train_data_loader = train_data_loader
+        if self.cfg.trainer.channels_last:
+            self.net_G = self.net_G.to(memory_format=torch.channels_last)
+            self.net_D = self.net_D.to(memory_format=torch.channels_last)
+
+        # Initialize amp.
+        if self.cfg.trainer.amp_config.enabled:
+            print("Using automatic mixed precision training.")
+        self.scaler_G = GradScaler(**vars(self.cfg.trainer.amp_config))
+        self.scaler_D = GradScaler(**vars(self.cfg.trainer.amp_config))
+        # Used to check whether the discriminator/generator skipped the last
+        # parameter update due to gradient overflow.
+        self.last_step_count_G = 0
+        self.last_step_count_D = 0
+        self.skipped_G = False
+        self.skipped_D = False
+
+        # Initialize data augmentation policy.
+        self.aug_policy = cfg.trainer.aug_policy
+        print("Augmentation policy: {}".format(self.aug_policy))
+
+        # Initialize loss functions.
+        # All loss names have weights. Some have criterion modules.
+        # Mapping from loss names to criterion modules.
+        self.criteria = torch.nn.ModuleDict()
+        # Mapping from loss names to loss weights.
+        self.weights = dict()
+        self.losses = dict(gen_update=dict(), dis_update=dict())
+        self.gen_losses = self.losses['gen_update']
+        self.dis_losses = self.losses['dis_update']
+        self._init_loss(cfg)
+        for loss_name, loss_weight in self.weights.items():
+            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
+            if loss_name in self.criteria.keys() and \
+                    self.criteria[loss_name] is not None:
+                self.criteria[loss_name].to('cuda')
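+        # Note: _init_loss() is implemented by each subclass. A typical
+        # (hypothetical) implementation registers entries such as
+        #     self.criteria['GAN'] = SomeGANLoss()
+        #     self.weights['GAN'] = cfg.trainer.loss_weight.gan
+        # so that the loop above can report the weights and move criterion
+        # modules to the GPU.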
+
+        if self.is_inference:
+            # The initialization steps below can be skipped during inference.
+            return
+
+        # Initialize logging attributes.
+        self.current_iteration = 0
+        self.current_epoch = 0
+        self.start_iteration_time = None
+        self.start_epoch_time = None
+        self.elapsed_iteration_time = 0
+        self.time_iteration = None
+        self.time_epoch = None
+        self.best_fid = None
+        if self.cfg.speed_benchmark:
+            self.accu_gen_forw_iter_time = 0
+            self.accu_gen_loss_iter_time = 0
+            self.accu_gen_back_iter_time = 0
+            self.accu_gen_step_iter_time = 0
+            self.accu_gen_avg_iter_time = 0
+            self.accu_dis_forw_iter_time = 0
+            self.accu_dis_loss_iter_time = 0
+            self.accu_dis_back_iter_time = 0
+            self.accu_dis_step_iter_time = 0
+
+        # Initialize tensorboard and hparams.
+        self._init_tensorboard()
+        self._init_hparams()
+
+        # Initialize validation parameters.
+        self.val_sample_size = getattr(cfg.trainer, 'val_sample_size', 50000)
+        self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 10)
+        self.kid_subset_size = self.val_sample_size // self.kid_num_subsets
+        self.metrics_path = os.path.join(torch.hub.get_dir(), 'metrics')
+        self.best_metrics = {}
+        self.eval_networks = getattr(cfg.trainer, 'eval_network', ['clean_inception'])
+        if self.cfg.metrics_iter is None:
+            self.cfg.metrics_iter = self.cfg.snapshot_save_iter
+        if self.cfg.metrics_epoch is None:
+            self.cfg.metrics_epoch = self.cfg.snapshot_save_epoch
+
+        # AWS credentials.
+        if hasattr(cfg, 'aws_credentials_file'):
+            with open(cfg.aws_credentials_file) as fin:
+                self.credentials = json.load(fin)
+        else:
+            self.credentials = None
+
+        if 'TORCH_HOME' not in os.environ:
+            os.environ['TORCH_HOME'] = os.path.join(
+                os.environ['HOME'], ".cache")
+
+    def _init_tensorboard(self):
+        r"""Initialize the tensorboard. Different algorithms might require
+        different performance metrics. Hence, custom tensorboard
+        initialization might be necessary.
+        """
+        # Logging frequency: self.cfg.logging_iter
+        self.meters = {}
+
+        # Logging frequency: self.cfg.snapshot_save_iter
+        self.metric_meters = {}
+
+        # Logging frequency: self.cfg.image_display_iter
+        self.image_meter = Meter('images', reduce=False)
+
+    def _init_hparams(self):
+        r"""Initialize a dictionary of hyperparameters that we want to monitor
+        in the HParams dashboard in TensorBoard.
+        """
+        self.hparam_dict = {}
+
+    def _write_tensorboard(self):
+        r"""Write values to tensorboard. By default, we will log the time used
+        per iteration, time used per epoch, generator learning rate, and
+        discriminator learning rate. We will log all the losses as well as
+        custom meters.
+        """
+        # Logs that are shared by all models.
+        self._write_to_meters({'time/iteration': self.time_iteration,
+                               'time/epoch': self.time_epoch,
+                               'optim/gen_lr': self.sch_G.get_last_lr()[0],
+                               'optim/dis_lr': self.sch_D.get_last_lr()[0]},
+                              self.meters,
+                              reduce=False)
+        # Logs for loss values. Different models have different losses.
+        self._write_loss_meters()
+        # Other custom logs.
+        self._write_custom_meters()
+
+    def _write_loss_meters(self):
+        r"""Write all loss values to tensorboard."""
+        for update, losses in self.losses.items():
+            # update is 'gen_update' or 'dis_update'.
+            assert update == 'gen_update' or update == 'dis_update'
+            for loss_name, loss in losses.items():
+                if loss is not None:
+                    full_loss_name = update + '/' + loss_name
+                    if full_loss_name not in self.meters.keys():
+                        # Create a new meter if it doesn't exist.
+                        self.meters[full_loss_name] = Meter(
+                            full_loss_name, reduce=True)
+                    self.meters[full_loss_name].write(loss.item())
+
+    def _write_custom_meters(self):
+        r"""Dummy member function to be overloaded by the child class.
+        In the child class, you can write down whatever you want to track.
+        """
+        pass
+
+    @staticmethod
+    def _write_to_meters(data, meters, reduce=True):
+        r"""Write values to meters."""
+        if reduce or is_master():
+            for key, value in data.items():
+                if key not in meters:
+                    meters[key] = Meter(key, reduce=reduce)
+                meters[key].write(value)
+
+    def _flush_meters(self, meters):
+        r"""Flush all meters using the current iteration."""
+        for meter in meters.values():
+            meter.flush(self.current_iteration)
+
+    def _pre_save_checkpoint(self):
+        r"""Implement the things you want to do before saving a checkpoint.
+        For example, you can compute the K-mean features (pix2pixHD) before
+        saving the model weights to a checkpoint.
+        """
+        pass
+
+    def save_checkpoint(self, current_epoch, current_iteration):
+        r"""Save network weights, optimizer parameters, scheduler parameters
+        to a checkpoint.
+        """
+        self._pre_save_checkpoint()
+        _save_checkpoint(self.cfg,
+                         self.net_G, self.net_D,
+                         self.opt_G, self.opt_D,
+                         self.sch_G, self.sch_D,
+                         current_epoch, current_iteration)
+
+    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
+        r"""Load network weights, optimizer parameters, scheduler parameters
+        from a checkpoint.
+
+        Args:
+            cfg (obj): Global configuration.
+            checkpoint_path (str): Path to the checkpoint.
+            resume (bool or None): If not ``None``, will determine whether or
+                not to load optimizers in addition to network weights.
+        """
+        if os.path.exists(checkpoint_path):
+            # If checkpoint_path exists, we will load its weights to
+            # initialize our network.
+            if resume is None:
+                resume = False
+        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
+            # This is for resuming the training from the previously saved
+            # checkpoint.
+            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
+            with open(fn, 'r') as f:
+                line = f.read().splitlines()
+            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
+            if resume is None:
+                resume = True
+        else:
+            # checkpoint not found and not specified. We will train
+            # everything from scratch.
+            current_epoch = 0
+            current_iteration = 0
+            print('No checkpoint found.')
+            resume = False
+            return resume, current_epoch, current_iteration
+        # Load checkpoint
+        checkpoint = torch.load(
+            checkpoint_path, map_location=lambda storage, loc: storage)
+        current_epoch = 0
+        current_iteration = 0
+        if resume:
+            self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
+            if not self.is_inference:
+                self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
+                if 'opt_G' in checkpoint:
+                    current_epoch = checkpoint['current_epoch']
+                    current_iteration = checkpoint['current_iteration']
+                    self.opt_G.load_state_dict(checkpoint['opt_G'])
+                    self.opt_D.load_state_dict(checkpoint['opt_D'])
+                    if load_sch:
+                        self.sch_G.load_state_dict(checkpoint['sch_G'])
+                        self.sch_D.load_state_dict(checkpoint['sch_D'])
+                    else:
+                        if self.cfg.gen_opt.lr_policy.iteration_mode:
+                            self.sch_G.last_epoch = current_iteration
+                        else:
+                            self.sch_G.last_epoch = current_epoch
+                        if self.cfg.dis_opt.lr_policy.iteration_mode:
+                            self.sch_D.last_epoch = current_iteration
+                        else:
+                            self.sch_D.last_epoch = current_epoch
+                    print('Load from: {}'.format(checkpoint_path))
+                else:
+                    print('Load network weights only.')
+        else:
+            try:
+                self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
+                if 'net_D' in checkpoint:
+                    self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
+            except Exception:
+                if self.cfg.trainer.model_average_config.enabled:
+                    net_G_module = self.net_G.module.module
+                else:
+                    net_G_module = self.net_G.module
+                if hasattr(net_G_module, 'load_pretrained_network'):
+                    net_G_module.load_pretrained_network(self.net_G, checkpoint['net_G'])
+                    print('Load generator weights only.')
+                else:
+                    raise ValueError('Checkpoint cannot be loaded.')
+
+        print('Done with loading the checkpoint.')
+        return resume, current_epoch, current_iteration
+
+    def start_of_epoch(self, current_epoch):
+        r"""Things to do before an epoch.
+
+        Args:
+            current_epoch (int): Current number of epoch.
+        """
+        self._start_of_epoch(current_epoch)
+        self.current_epoch = current_epoch
+        self.start_epoch_time = time.time()
+
+    def start_of_iteration(self, data, current_iteration):
+        r"""Things to do before an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current number of iteration.
+        """
+        data = self._start_of_iteration(data, current_iteration)
+        data = to_cuda(data)
+        if self.cfg.trainer.channels_last:
+            data = to_channels_last(data)
+        self.current_iteration = current_iteration
+        if not self.is_inference:
+            self.net_D.train()
+        self.net_G.train()
+        # torch.cuda.synchronize()
+        self.start_iteration_time = time.time()
+        return data
+
+    def end_of_iteration(self, data, current_epoch, current_iteration):
+        r"""Things to do after an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_epoch (int): Current number of epoch.
+            current_iteration (int): Current number of iteration.
+        """
+        self.current_iteration = current_iteration
+        self.current_epoch = current_epoch
+        # Update the learning rate policy for the generator if operating in the
+        # iteration mode.
+        if self.cfg.gen_opt.lr_policy.iteration_mode:
+            self.sch_G.step()
+        # Update the learning rate policy for the discriminator if operating in
+        # the iteration mode.
+        if self.cfg.dis_opt.lr_policy.iteration_mode:
+            self.sch_D.step()
+
+        # Accumulate time
+        # torch.cuda.synchronize()
+        self.elapsed_iteration_time += time.time() - self.start_iteration_time
+        # Logging.
+        if current_iteration % self.cfg.logging_iter == 0:
+            ave_t = self.elapsed_iteration_time / self.cfg.logging_iter
+            self.time_iteration = ave_t
+            print('Iteration: {}, average iter time: '
+                  '{:6f}.'.format(current_iteration, ave_t))
+            self.elapsed_iteration_time = 0
+
+            if self.cfg.speed_benchmark:
+                # Below code block only needed when analyzing computation
+                # bottleneck.
+                print('\tGenerator FWD time {:6f}'.format(
+                    self.accu_gen_forw_iter_time / self.cfg.logging_iter))
+                print('\tGenerator LOS time {:6f}'.format(
+                    self.accu_gen_loss_iter_time / self.cfg.logging_iter))
+                print('\tGenerator BCK time {:6f}'.format(
+                    self.accu_gen_back_iter_time / self.cfg.logging_iter))
+                print('\tGenerator STP time {:6f}'.format(
+                    self.accu_gen_step_iter_time / self.cfg.logging_iter))
+                print('\tGenerator AVG time {:6f}'.format(
+                    self.accu_gen_avg_iter_time / self.cfg.logging_iter))
+
+                print('\tDiscriminator FWD time {:6f}'.format(
+                    self.accu_dis_forw_iter_time / self.cfg.logging_iter))
+                print('\tDiscriminator LOS time {:6f}'.format(
+                    self.accu_dis_loss_iter_time / self.cfg.logging_iter))
+                print('\tDiscriminator BCK time {:6f}'.format(
+                    self.accu_dis_back_iter_time / self.cfg.logging_iter))
+                print('\tDiscriminator STP time {:6f}'.format(
+                    self.accu_dis_step_iter_time / self.cfg.logging_iter))
+
+                print('{:6f}'.format(ave_t))
+
+                self.accu_gen_forw_iter_time = 0
+                self.accu_gen_loss_iter_time = 0
+                self.accu_gen_back_iter_time = 0
+                self.accu_gen_step_iter_time = 0
+                self.accu_gen_avg_iter_time = 0
+                self.accu_dis_forw_iter_time = 0
+                self.accu_dis_loss_iter_time = 0
+                self.accu_dis_back_iter_time = 0
+                self.accu_dis_step_iter_time = 0
+
+        self._end_of_iteration(data, current_epoch, current_iteration)
+
+        # Save everything to the checkpoint.
+        if current_iteration % self.cfg.snapshot_save_iter == 0:
+            if current_iteration >= self.cfg.snapshot_save_start_iter:
+                self.save_checkpoint(current_epoch, current_iteration)
+
+        # Compute metrics.
+        if current_iteration % self.cfg.metrics_iter == 0:
+            self.save_image(self._get_save_path('images', 'jpg'), data)
+            self.write_metrics()
+
+        # Compute image to be saved.
+        elif current_iteration % self.cfg.image_save_iter == 0:
+            self.save_image(self._get_save_path('images', 'jpg'), data)
+        elif current_iteration % self.cfg.image_display_iter == 0:
+            image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg')
+            self.save_image(image_path, data)
+
+        # Logging.
+        self._write_tensorboard()
+        if current_iteration % self.cfg.logging_iter == 0:
+            # Write all logs to tensorboard.
+            self._flush_meters(self.meters)
+
+        import torch.distributed as dist
+        if dist.is_initialized():
+            dist.barrier()
+
+    def end_of_epoch(self, data, current_epoch, current_iteration):
+        r"""Things to do after an epoch.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_epoch (int): Current number of epoch.
+            current_iteration (int): Current number of iteration.
+        """
+        # Update the learning rate policy for the generator if operating in the
+        # epoch mode.
+        self.current_iteration = current_iteration
+        self.current_epoch = current_epoch
+        if not self.cfg.gen_opt.lr_policy.iteration_mode:
+            self.sch_G.step()
+        # Update the learning rate policy for the discriminator if operating
+        # in the epoch mode.
+        if not self.cfg.dis_opt.lr_policy.iteration_mode:
+            self.sch_D.step()
+        elapsed_epoch_time = time.time() - self.start_epoch_time
+        # Logging.
+        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
+                                                     elapsed_epoch_time))
+        self.time_epoch = elapsed_epoch_time
+        self._end_of_epoch(data, current_epoch, current_iteration)
+
+        # Save everything to the checkpoint.
+        if current_iteration % self.cfg.snapshot_save_iter == 0:
+            if current_epoch >= self.cfg.snapshot_save_start_epoch:
+                self.save_checkpoint(current_epoch, current_iteration)
+
+        # Compute metrics.
+        if current_iteration % self.cfg.metrics_iter == 0:
+            self.save_image(self._get_save_path('images', 'jpg'), data)
+            self.write_metrics()
+
+    def pre_process(self, data):
+        r"""Custom data pre-processing function. Utilize this function if you
+        need to preprocess your data before sending it to the generator and
+        discriminator.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+
+    def recalculate_batch_norm_statistics(self, data_loader, averaged=True):
+        r"""Update the statistics in the moving average model.
+
+        Args:
+            data_loader (torch.utils.data.DataLoader): Data loader for
+                estimating the statistics.
+            averaged (bool): If True, recalculate the batch norm statistics
+                of the model-average (EMA) model; otherwise of the regular model.
+        """
+        if not self.cfg.trainer.model_average_config.enabled:
+            return
+        if averaged:
+            net_G = self.net_G.module.averaged_model
+        else:
+            net_G = self.net_G_module
+        model_average_iteration = \
+            self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations
+        if model_average_iteration == 0:
+            return
+        with torch.no_grad():
+            # Accumulate batch norm statistics.
+            net_G.train()
+            # Reset running stats.
+            net_G.apply(reset_batch_norm)
+            for cal_it, cal_data in enumerate(data_loader):
+                if cal_it >= model_average_iteration:
+                    print('Done with {} iterations of updating batch norm '
+                          'statistics'.format(model_average_iteration))
+                    break
+                cal_data = to_device(cal_data, 'cuda')
+                cal_data = self.pre_process(cal_data)
+                # Averaging over all batches
+                net_G.apply(calibrate_batch_norm_momentum)
+                net_G(cal_data)
+
+    def save_image(self, path, data):
+        r"""Compute visualization images and save them to the disk.
+
+        Args:
+            path (str): Location of the file.
+            data (dict): Data used for the current iteration.
+        """
+        self.net_G.eval()
+        vis_images = self._get_visualizations(data)
+        if is_master() and vis_images is not None:
+            vis_images = torch.cat(
+                [img for img in vis_images if img is not None], dim=3).float()
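+            # Visualization tensors are in [-1, 1]; map them to [0, 1] before saving.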
+            vis_images = (vis_images + 1) / 2
+            print('Save output images to {}'.format(path))
+            vis_images.clamp_(0, 1)
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            image_grid = torchvision.utils.make_grid(
+                vis_images, nrow=1, padding=0, normalize=False)
+            if self.cfg.trainer.image_to_tensorboard:
+                self.image_meter.write_image(image_grid, self.current_iteration)
+            torchvision.utils.save_image(image_grid, path, nrow=1)
+            wandb.log({os.path.splitext(os.path.basename(path))[0]: [wandb.Image(path)]})
+
+    def write_metrics(self):
+        r"""Write metrics to the tensorboard."""
+        cur_fid = self._compute_fid()
+        if cur_fid is not None:
+            if self.best_fid is not None:
+                self.best_fid = min(self.best_fid, cur_fid)
+            else:
+                self.best_fid = cur_fid
+            metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid}
+            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
+            self._flush_meters(self.metric_meters)
+
+    def _get_save_path(self, subdir, ext):
+        r"""Get the image save path.
+
+        Args:
+            subdir (str): Sub-directory under the main directory for saving
+                the outputs.
+            ext (str): Filename extension for the image (e.g., jpg, png, ...).
+        Return:
+            (str): image filename to be used to save the visualization results.
+        """
+        subdir_path = os.path.join(self.cfg.logdir, subdir)
+        if not os.path.exists(subdir_path):
+            os.makedirs(subdir_path, exist_ok=True)
+        return os.path.join(
+            subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format(
+                self.current_epoch, self.current_iteration, ext))
+
+    def _get_outputs(self, net_D_output, real=True):
+        r"""Return output values. Note that when the gan mode is relativistic.
+        It will do the difference before returning.
+
+        Args:
+           net_D_output (dict):
+               real_outputs (tensor): Real output values.
+               fake_outputs (tensor): Fake output values.
+           real (bool): Return real or fake.
+        """
+
+        def _get_difference(a, b):
+            r"""Get difference between two lists of tensors or two tensors.
+
+            Args:
+                a: list of tensors or tensor
+                b: list of tensors or tensor
+            """
+            out = list()
+            for x, y in zip(a, b):
+                if isinstance(x, list):
+                    res = _get_difference(x, y)
+                else:
+                    res = x - y
+                out.append(res)
+            return out
+
+        if real:
+            if self.cfg.trainer.gan_relativistic:
+                return _get_difference(net_D_output['real_outputs'], net_D_output['fake_outputs'])
+            else:
+                return net_D_output['real_outputs']
+        else:
+            if self.cfg.trainer.gan_relativistic:
+                return _get_difference(net_D_output['fake_outputs'], net_D_output['real_outputs'])
+            else:
+                return net_D_output['fake_outputs']
+
+    def _start_of_epoch(self, current_epoch):
+        r"""Operations to do before starting an epoch.
+
+        Args:
+            current_epoch (int): Current number of epoch.
+        """
+        pass
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Operations to do before starting an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current iteration number.
+        Returns:
+            (dict): Data used for the current iteration. They might be
+                processed by the custom _start_of_iteration function.
+        """
+        return data
+
+    def _end_of_iteration(self, data, current_epoch, current_iteration):
+        r"""Operations to do after an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_epoch (int): Current number of epoch.
+            current_iteration (int): Current iteration number.
+        """
+        pass
+
+    def _end_of_epoch(self, data, current_epoch, current_iteration):
+        r"""Operations to do after an epoch.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_epoch (int): Current number of epoch.
+            current_iteration (int): Current iteration number.
+        """
+        pass
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization outputs.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+        return None
+
+    def _compute_fid(self):
+        r"""FID computation function to be overloaded."""
+        return None
+
+    def _init_loss(self, cfg):
+        r"""Every trainer should implement its own init loss function."""
+        raise NotImplementedError
+
+    def gen_update(self, data):
+        r"""Update the generator.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+        update_finished = False
+        while not update_finished:
+            # Set requires_grad flags.
+            requires_grad(self.net_G_module, True)
+            requires_grad(self.net_D, False)
+
+            # Compute the loss.
+            self._time_before_forward()
+            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
+                total_loss = self.gen_forward(data)
+            if total_loss is None:
+                return
+
+            # Zero-grad and backpropagate the loss.
+            self.opt_G.zero_grad(set_to_none=True)
+            self._time_before_backward()
+            self.scaler_G.scale(total_loss).backward()
+
+            # Optionally clip gradient norm.
+            if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
+                self.scaler_G.unscale_(self.opt_G)
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    self.net_G_module.parameters(),
+                    self.cfg.gen_opt.clip_grad_norm
+                )
+                self.gen_grad_norm = total_norm
+                if torch.isfinite(total_norm) and \
+                        total_norm > self.cfg.gen_opt.clip_grad_norm:
+                    # print(f"Gradient norm of the generator ({total_norm}) "
+                    #       f"too large.")
+                    if getattr(self.cfg.gen_opt, 'skip_grad', False):
+                        print(f"Skip gradient update.")
+                        self.opt_G.zero_grad(set_to_none=True)
+                        self.scaler_G.step(self.opt_G)
+                        self.scaler_G.update()
+                        break
+                    # else:
+                    #     print(f"Clip gradient norm to "
+                    #           f"{self.cfg.gen_opt.clip_grad_norm}.")
+
+            # Perform an optimizer step.
+            self._time_before_step()
+            self.scaler_G.step(self.opt_G)
+            self.scaler_G.update()
+            # Whether the step above was skipped.
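+            # GradScaler.step() only invokes opt_G.step() when no inf/NaN
+            # gradients are found, so opt_G._step_count stays unchanged on a
+            # skipped update; comparing it with the last recorded count
+            # exposes the overflow.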
+            if self.last_step_count_G == self.opt_G._step_count:
+                print("Generator overflowed!")
+                if not torch.isfinite(total_loss):
+                    print("Generator loss is not finite. Skip this iteration!")
+                    update_finished = True
+            else:
+                self.last_step_count_G = self.opt_G._step_count
+                update_finished = True
+
+        self._extra_gen_step(data)
+
+        # Update model average.
+        self._time_before_model_avg()
+        if self.cfg.trainer.model_average_config.enabled:
+            self.net_G.module.update_average()
+
+        self._detach_losses()
+        self._time_before_leave_gen()
+
+    def gen_forward(self, data):
+        r"""Every trainer should implement its own generator forward."""
+        raise NotImplementedError
+
+    def _extra_gen_step(self, data):
+        pass
+
+    def dis_update(self, data):
+        r"""Update the discriminator.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+        update_finished = False
+        while not update_finished:
+            # Set requires_grad flags.
+            requires_grad(self.net_G_module, False)
+            requires_grad(self.net_D, True)
+
+            # Compute the loss.
+            self._time_before_forward()
+            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
+                total_loss = self.dis_forward(data)
+            if total_loss is None:
+                return
+
+            # Zero-grad and backpropagate the loss.
+            self.opt_D.zero_grad(set_to_none=True)
+            self._time_before_backward()
+            self.scaler_D.scale(total_loss).backward()
+
+            # Optionally clip gradient norm.
+            if hasattr(self.cfg.dis_opt, 'clip_grad_norm'):
+                self.scaler_D.unscale_(self.opt_D)
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm
+                )
+                self.dis_grad_norm = total_norm
+                if torch.isfinite(total_norm) and \
+                        total_norm > self.cfg.dis_opt.clip_grad_norm:
+                    print(f"Gradient norm of the discriminator ({total_norm}) "
+                          f"too large.")
+                    if getattr(self.cfg.dis_opt, 'skip_grad', False):
+                        print(f"Skip gradient update.")
+                        self.opt_D.zero_grad(set_to_none=True)
+                        self.scaler_D.step(self.opt_D)
+                        self.scaler_D.update()
+                        continue
+                    else:
+                        print(f"Clip gradient norm to "
+                              f"{self.cfg.dis_opt.clip_grad_norm}.")
+
+            # Perform an optimizer step.
+            self._time_before_step()
+            self.scaler_D.step(self.opt_D)
+            self.scaler_D.update()
+            # Whether the step above was skipped.
+            if self.last_step_count_D == self.opt_D._step_count:
+                print("Discriminator overflowed!")
+                if not torch.isfinite(total_loss):
+                    print("Discriminator loss is not finite. "
+                          "Skip this iteration!")
+                    update_finished = True
+            else:
+                self.last_step_count_D = self.opt_D._step_count
+                update_finished = True
+
+        self._extra_dis_step(data)
+
+        self._detach_losses()
+        self._time_before_leave_dis()
+
+    def dis_forward(self, data):
+        r"""Every trainer should implement its own discriminator forward."""
+        raise NotImplementedError
+
+    def _extra_dis_step(self, data):
+        pass
+
+    def test(self, data_loader, output_dir, inference_args):
+        r"""Compute results images for a batch of input data and save the
+        results in the specified folder.
+
+        Args:
+            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
+            output_dir (str): Target location for saving the output image.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G = self.net_G.module.averaged_model
+        else:
+            net_G = self.net_G.module
+        net_G.eval()
+
+        print('# of samples %d' % len(data_loader))
+        for it, data in enumerate(tqdm(data_loader)):
+            data = self.start_of_iteration(data, current_iteration=-1)
+            with torch.no_grad():
+                output_images, file_names = \
+                    net_G.inference(data, **vars(inference_args))
+            for output_image, file_name in zip(output_images, file_names):
+                fullname = os.path.join(output_dir, file_name + '.jpg')
+                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
+                                               minus1to1_normalized=True)
+                save_pilimage_in_jpeg(fullname, output_image)
+
+    def _get_total_loss(self, gen_forward):
+        r"""Return the total loss to be backpropagated.
+        Args:
+            gen_forward (bool): If ``True``, backpropagates the generator loss,
+                otherwise the discriminator loss.
+        """
+        losses = self.gen_losses if gen_forward else self.dis_losses
+        total_loss = torch.tensor(0., device=torch.device('cuda'))
+        # Iterates over all possible losses.
+        for loss_name in self.weights:
+            # If it is for the current model (gen/dis).
+            if loss_name in losses:
+                # Multiply it with the corresponding weight
+                # and add it to the total loss.
+                total_loss += losses[loss_name] * self.weights[loss_name]
+        losses['total'] = total_loss  # logging purpose
+        return total_loss
+
+    def _detach_losses(self):
+        r"""Detach all logging variables to prevent potential memory leak."""
+        for loss_name in self.gen_losses:
+            self.gen_losses[loss_name] = self.gen_losses[loss_name].detach()
+        for loss_name in self.dis_losses:
+            self.dis_losses[loss_name] = self.dis_losses[loss_name].detach()
+
+    def _time_before_forward(self):
+        r"""
+        Record time before applying forward.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            self.forw_time = time.time()
+
+    def _time_before_loss(self):
+        r"""
+        Record time before computing loss.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            self.loss_time = time.time()
+
+    def _time_before_backward(self):
+        r"""
+        Record time before applying backward.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            self.back_time = time.time()
+
+    def _time_before_step(self):
+        r"""
+        Record time before updating the weights
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            self.step_time = time.time()
+
+    def _time_before_model_avg(self):
+        r"""
+        Record time before applying model average.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            self.avg_time = time.time()
+
+    def _time_before_leave_gen(self):
+        r"""
+        Record forward, backward, loss, and model average time for the
+        generator update.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            end_time = time.time()
+            self.accu_gen_forw_iter_time += self.loss_time - self.forw_time
+            self.accu_gen_loss_iter_time += self.back_time - self.loss_time
+            self.accu_gen_back_iter_time += self.step_time - self.back_time
+            self.accu_gen_step_iter_time += self.avg_time - self.step_time
+            self.accu_gen_avg_iter_time += end_time - self.avg_time
+
+    def _time_before_leave_dis(self):
+        r"""
+        Record forward, backward, loss time for the discriminator update.
+        """
+        if self.cfg.speed_benchmark:
+            torch.cuda.synchronize()
+            end_time = time.time()
+            self.accu_dis_forw_iter_time += self.loss_time - self.forw_time
+            self.accu_dis_loss_iter_time += self.back_time - self.loss_time
+            self.accu_dis_back_iter_time += self.step_time - self.back_time
+            self.accu_dis_step_iter_time += end_time - self.step_time
+
+
+@master_only
+def _save_checkpoint(cfg,
+                     net_G, net_D,
+                     opt_G, opt_D,
+                     sch_G, sch_D,
+                     current_epoch, current_iteration):
+    r"""Save network weights, optimizer parameters, scheduler parameters
+    in the checkpoint.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        current_epoch (int): Current epoch.
+        current_iteration (int): Current iteration.
+    """
+    latest_checkpoint_path = 'epoch_{:05}_iteration_{:09}_checkpoint.pt'.format(
+        current_epoch, current_iteration)
+    save_path = os.path.join(cfg.logdir, latest_checkpoint_path)
+    torch.save(
+        {
+            'net_G': net_G.state_dict(),
+            'net_D': net_D.state_dict(),
+            'opt_G': opt_G.state_dict(),
+            'opt_D': opt_D.state_dict(),
+            'sch_G': sch_G.state_dict(),
+            'sch_D': sch_D.state_dict(),
+            'current_epoch': current_epoch,
+            'current_iteration': current_iteration,
+        },
+        save_path,
+    )
+    fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
+    with open(fn, 'wt') as f:
+        f.write('latest_checkpoint: %s' % latest_checkpoint_path)
+    print('Save checkpoint to {}'.format(save_path))
+    return save_path
diff --git a/imaginaire/trainers/fs_vid2vid.py b/imaginaire/trainers/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f30b4dd588667b8cef1e433c5503f9ac419e190
--- /dev/null
+++ b/imaginaire/trainers/fs_vid2vid.py
@@ -0,0 +1,292 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import imageio
+import numpy as np
+import torch
+from tqdm import tqdm
+
+
+from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask,
+                                               pre_process_densepose,
+                                               random_roll)
+from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params
+from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.misc import to_cuda
+from imaginaire.utils.visualization import tensor2flow, tensor2im
+
+
+class Trainer(vid2vidTrainer):
+    r"""Initialize vid2vid trainer.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Things to do before an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current number of iteration.
+        """
+        data = self.pre_process(data)
+        return to_cuda(data)
+
+    def pre_process(self, data):
+        r"""Do any data pre-processing here.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+        data_cfg = self.cfg.data
+        if hasattr(data_cfg, 'for_pose_dataset') and \
+                ('pose_maps-densepose' in data_cfg.input_labels):
+            pose_cfg = data_cfg.for_pose_dataset
+            data['label'] = pre_process_densepose(pose_cfg, data['label'],
+                                                  self.is_inference)
+            data['few_shot_label'] = pre_process_densepose(
+                pose_cfg, data['few_shot_label'], self.is_inference)
+        return data
+
+    def get_test_output_images(self, data):
+        r"""Get the visualization output of test function.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        vis_images = [
+            tensor2im(data['few_shot_images'][:, 0]),
+            self.visualize_label(data['label'][:, -1]),
+            tensor2im(data['images'][:, -1]),
+            tensor2im(self.net_G_output['fake_images']),
+        ]
+        return vis_images
+
+    def get_data_t(self, data, net_G_output, data_prev, t):
+        r"""Get data at current time frame given the sequence of data.
+
+        Args:
+            data (dict): Training data for current iteration.
+            net_G_output (dict): Output of the generator (for previous frame).
+            data_prev (dict): Data for previous frame.
+            t (int): Current time.
+        """
+        label = data['label'][:, t] if 'label' in data else None
+        image = data['images'][:, t]
+
+        if data_prev is not None:
+            nG = self.cfg.data.num_frames_G
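+            # Keep the most recent (num_frames_G - 1) labels/frames as
+            # temporal conditioning for generating the current frame.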
+            prev_labels = concat_frames(data_prev['prev_labels'],
+                                        data_prev['label'], nG - 1)
+            prev_images = concat_frames(
+                data_prev['prev_images'],
+                net_G_output['fake_images'].detach(), nG - 1)
+        else:
+            prev_labels = prev_images = None
+
+        data_t = dict()
+        data_t['label'] = label
+        data_t['image'] = image
+        data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \
+                                                         in data else None
+        data_t['ref_images'] = data['few_shot_images']
+        data_t['prev_labels'] = prev_labels
+        data_t['prev_images'] = prev_images
+        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
+
+        # if 'landmarks_xy' in data:
+        #     data_t['landmarks_xy'] = data['landmarks_xy'][:, t]
+        #     data_t['ref_landmarks_xy'] = data['few_shot_landmarks_xy']
+        return data_t
+
+    def post_process(self, data, net_G_output):
+        r"""Do any postprocessing of the data / output here.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            net_G_output (dict): Output of the generator.
+        """
+        if self.has_fg:
+            fg_mask = get_fg_mask(data['label'], self.has_fg)
+            if net_G_output['fake_raw_images'] is not None:
+                net_G_output['fake_raw_images'] = \
+                    net_G_output['fake_raw_images'] * fg_mask
+
+        return data, net_G_output
+
+    def test(self, test_data_loader, root_output_dir, inference_args):
+        r"""Run inference on the specified sequence.
+
+        Args:
+            test_data_loader (object): Test data loader.
+            root_output_dir (str): Location to dump outputs.
+            inference_args (optional): Optional args.
+        """
+        self.reset()
+        test_data_loader.dataset.set_sequence_length(0)
+        # Set the inference sequences.
+        test_data_loader.dataset.set_inference_sequence_idx(
+            inference_args.driving_seq_index,
+            inference_args.few_shot_seq_index,
+            inference_args.few_shot_frame_index)
+
+        video = []
+        for idx, data in enumerate(tqdm(test_data_loader)):
+            key = data['key']['images'][0][0]
+            filename = key.split('/')[-1]
+
+            # Create output dir for this sequence.
+            if idx == 0:
+                seq_name = '%03d' % inference_args.driving_seq_index
+                output_dir = os.path.join(root_output_dir, seq_name)
+                os.makedirs(output_dir, exist_ok=True)
+                video_path = output_dir
+
+            # Get output and save images.
+            data['img_name'] = filename
+            data = self.start_of_iteration(data, current_iteration=-1)
+            output = self.test_single(data, output_dir, inference_args)
+            video.append(output)
+
+        # Save output as mp4.
+        imageio.mimsave(video_path + '.mp4', video, fps=15)
+
+    def save_image(self, path, data):
+        r"""Save the output images to path.
+        Note that when generate_raw_output is False,
+        first_net_G_output['fake_raw_images'] is None and will not be displayed.
+        In model average mode, the flow visualization is plotted twice.
+
+        Args:
+            path (str): Save path.
+            data (dict): Training data for current iteration.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            self.net_G.module.averaged_model.eval()
+
+        self.net_G_output = None
+        with torch.no_grad():
+            first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
+            if self.cfg.trainer.model_average_config.enabled:
+                first_net_G_output_avg, last_net_G_output_avg, _ = \
+                    self.gen_frames(data, use_model_average=True)
+
+        def get_images(data, net_G_output, return_first_frame=True,
+                       for_model_average=False):
+            r"""Get the ourput images to save.
+
+            Args:
+                data (dict): Training data for current iteration.
+                net_G_output (dict): Generator output.
+                return_first_frame (bool): Return output for first frame in the
+                sequence.
+                for_model_average (bool): For model average output.
+            Return:
+                vis_images (list of numpy arrays): Visualization images.
+            """
+            frame_idx = 0 if return_first_frame else -1
+            warped_idx = 0 if return_first_frame else 1
+            vis_images = []
+            if not for_model_average:
+                vis_images += [
+                    tensor2im(data['few_shot_images'][:, frame_idx]),
+                    self.visualize_label(data['label'][:, frame_idx]),
+                    tensor2im(data['images'][:, frame_idx])
+                ]
+            vis_images += [
+                tensor2im(net_G_output['fake_images']),
+                tensor2im(net_G_output['fake_raw_images'])]
+            if not for_model_average:
+                vis_images += [
+                    tensor2im(net_G_output['warped_images'][warped_idx]),
+                    tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
+                    tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
+                              normalize=False)
+                ]
+            return vis_images
+
+        if is_master():
+            vis_images_first = get_images(data, first_net_G_output)
+            if self.cfg.trainer.model_average_config.enabled:
+                vis_images_first += get_images(data, first_net_G_output_avg,
+                                               for_model_average=True)
+            if self.sequence_length > 1:
+                vis_images_last = get_images(data, last_net_G_output,
+                                             return_first_frame=False)
+                if self.cfg.trainer.model_average_config.enabled:
+                    vis_images_last += get_images(data, last_net_G_output_avg,
+                                                  return_first_frame=False,
+                                                  for_model_average=True)
+
+                # If generating a video, the first row of each batch will be
+                # the first generated frame and the flow/mask for warping the
+                # reference image, and the second row will be the last
+                # generated frame and the flow/mask for warping the previous
+                # frame. If using model average, the frames generated by model
+                # average will be at the rightmost columns.
+                vis_images = [[np.vstack((im_first, im_last))
+                               for im_first, im_last in
+                               zip(imgs_first, imgs_last)]
+                              for imgs_first, imgs_last in zip(vis_images_first,
+                                                               vis_images_last)
+                              if imgs_first is not None]
+            else:
+                vis_images = vis_images_first
+
+            image_grid = np.hstack([np.vstack(im) for im in vis_images
+                                    if im is not None])
+
+            print('Save output images to {}'.format(path))
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            imageio.imwrite(path, image_grid)
+
+    def finetune(self, data, inference_args):
+        r"""Finetune the model for a few iterations on the inference data."""
+        # Get the list of params to finetune.
+        self.net_G, self.net_D, self.opt_G, self.opt_D = \
+            get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
+                                      param_names_start_with=[
+                                          'weight_generator.fc', 'conv_img',
+                                          'up'])
+        data_finetune = {k: v for k, v in data.items()}
+        ref_labels = data_finetune['few_shot_label']
+        ref_images = data_finetune['few_shot_images']
+
+        # Number of iterations to finetune.
+        iterations = getattr(inference_args, 'finetune_iter', 100)
+        for it in range(1, iterations + 1):
+            # Randomly set one of the reference images as target.
+            idx = np.random.randint(ref_labels.size(1))
+            tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
+            # Randomly shift and flip the target image.
+            tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
+            data_finetune['label'] = tgt_label.unsqueeze(1)
+            data_finetune['images'] = tgt_image.unsqueeze(1)
+
+            self.gen_update(data_finetune)
+            self.dis_update(data_finetune)
+            # Log progress roughly ten times over the finetuning run.
+            if it % max(iterations // 10, 1) == 0:
+                print(it)
+
+        self.has_finetuned = True
diff --git a/imaginaire/trainers/funit.py b/imaginaire/trainers/funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..71141064cc92639120a8bd45b1a982bed0c5f37a
--- /dev/null
+++ b/imaginaire/trainers/funit.py
@@ -0,0 +1,244 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+import os
+from imaginaire.evaluation import compute_fid, compute_kid
+from imaginaire.utils.diff_aug import apply_diff_aug
+from imaginaire.losses import GANLoss
+from imaginaire.trainers.base import BaseTrainer
+from imaginaire.utils.distributed import is_master
+
+
+class Trainer(BaseTrainer):
+    r"""Reimplementation of the FUNIT (https://arxiv.org/abs/1905.01723)
+    algorithm.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        self.best_kid = None
+        self.use_fid = getattr(cfg.trainer, 'use_fid', False)
+        self.use_kid = getattr(cfg.trainer, 'use_kid', True)
+        self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 1)
+        self.kid_sample_size = getattr(cfg.trainer, 'kid_sample_size', 256)
+        self.kid_subset_size = getattr(cfg.trainer, 'kid_subset_size', 256)
+        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                         train_data_loader, val_data_loader)
+
+    def _init_loss(self, cfg):
+        r"""Initialize loss terms. In FUNIT, we have several loss terms
+        including the GAN loss, the image reconstruction loss, the feature
+        matching loss, and the gradient penalty loss.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
+        self.criteria['image_recon'] = nn.L1Loss()
+        self.criteria['feature_matching'] = nn.L1Loss()
+
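+        # Only losses with a positive weight are registered; the base trainer
+        # later sums weight * loss over these entries in _get_total_loss().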
+        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
+            if loss_weight > 0:
+                self.weights[loss_name] = loss_weight
+
+    def gen_forward(self, data):
+        r"""Compute the loss for FUNIT generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+
+        net_G_output = self.net_G(data)
+
+        # Differentiable augmentation.
+        keys = ['images_recon', 'images_trans']
+        net_D_output = self.net_D(data, apply_diff_aug(
+                                      net_G_output, keys, self.aug_policy))
+
+        self._time_before_loss()
+
+        # GAN loss
+        # We use both the translation and reconstruction streams.
+        if 'gan' in self.weights:
+            self.gen_losses['gan'] = 0.5 * (
+                    self.criteria['gan'](
+                        net_D_output['fake_out_trans'],
+                        True, dis_update=False) +
+                    self.criteria['gan'](
+                        net_D_output['fake_out_recon'],
+                        True, dis_update=False))
+
+        # Image reconstruction loss
+        if 'image_recon' in self.weights:
+            self.gen_losses['image_recon'] = \
+                self.criteria['image_recon'](net_G_output['images_recon'],
+                                             data['images_content'])
+
+        # Feature matching loss
+        if 'feature_matching' in self.weights:
+            self.gen_losses['feature_matching'] = \
+                self.criteria['feature_matching'](
+                    net_D_output['fake_features_trans'],
+                    net_D_output['real_features_style'])
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=True)
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for FUNIT discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        with torch.no_grad():
+            net_G_output = self.net_G(data)
+        net_G_output['images_trans'].requires_grad = True
+        net_D_output = self.net_D(
+            apply_diff_aug(data, ['images_style'], self.aug_policy),
+            apply_diff_aug(net_G_output, ['images_trans'], self.aug_policy),
+            recon=False)
+
+        self._time_before_loss()
+
+        self.dis_losses['gan'] = \
+            self.criteria['gan'](net_D_output['real_out_style'], True) + \
+            self.criteria['gan'](net_D_output['fake_out_trans'], False)
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=False)
+        return total_loss
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization image.
+
+        Args:
+            data (dict): The current batch.
+        """
+        net_G_for_evaluation = self.net_G
+        with torch.no_grad():
+            net_G_output = net_G_for_evaluation(data)
+            vis_images = [data['images_content'],
+                          data['images_style'],
+                          net_G_output['images_recon'],
+                          net_G_output['images_trans']]
+            _, _, h, w = net_G_output['images_recon'].size()
+            if 'attn_a' in net_G_output:
+                for i in range(net_G_output['attn_a'].size(1)):
+                    vis_images += [
+                        F.interpolate(
+                            net_G_output['attn_a'][:, i:i + 1, :, :], (
+                                h, w)).expand(-1, 3, -1, -1)]
+                for i in range(net_G_output['attn_a'].size(1)):
+                    vis_images += [
+                        F.interpolate(
+                            net_G_output['attn_b'][:, i:i + 1, :, :], (
+                                h, w)).expand(-1, 3, -1, -1)]
+            if self.cfg.trainer.model_average_config.enabled:
+                net_G_for_evaluation = self.net_G.module.averaged_model
+                net_G_output = net_G_for_evaluation(data)
+                vis_images += [net_G_output['images_recon'],
+                               net_G_output['images_trans']]
+            return vis_images
+
+    def _compute_fid(self):
+        r"""Compute FID. We will compute a FID value per test class. That is
+        if you have 30 test classes, we will compute 30 different FID values.
+        We will then report the mean of the FID values as the final
+        performance number as described in the FUNIT paper.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+
+        all_fid_values = []
+        num_test_classes = self.val_data_loader.dataset.num_style_classes
+        for class_idx in range(num_test_classes):
+            fid_path = self._get_save_path(os.path.join('fid', str(class_idx)),
+                                           'npy')
+            self.val_data_loader.dataset.set_sample_class_idx(class_idx)
+
+            fid_value = compute_fid(fid_path, self.val_data_loader,
+                                    net_G_for_evaluation, 'images_style',
+                                    'images_trans')
+            all_fid_values.append(fid_value)
+
+        if is_master():
+            mean_fid = np.mean(all_fid_values)
+            print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format(
+                self.current_epoch, self.current_iteration, mean_fid))
+            return mean_fid
+        else:
+            return None
+
+    def _compute_kid(self):
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+
+        all_kid_values = []
+        num_test_classes = self.val_data_loader.dataset.num_style_classes
+        for class_idx in range(num_test_classes):
+            kid_path = self._get_save_path(os.path.join('kid', str(class_idx)),
+                                           'npy')
+            self.val_data_loader.dataset.set_sample_class_idx(class_idx)
+
+            kid_value = compute_kid(
+                kid_path, self.val_data_loader, net_G_for_evaluation,
+                'images_style', 'images_trans',
+                num_subsets=self.kid_num_subsets,
+                sample_size=self.kid_sample_size,
+                subset_size=self.kid_subset_size)
+            all_kid_values.append(kid_value)
+
+        if is_master():
+            mean_kid = np.mean(all_kid_values)
+            print('Epoch {:05}, Iteration {:09}, Mean KID {}'.format(
+                self.current_epoch, self.current_iteration, mean_kid))
+            return mean_kid
+        else:
+            return None
+
+    def write_metrics(self):
+        r"""Write metrics to the tensorboard."""
+        metric_dict = {}
+        if self.use_kid:
+            cur_kid = self._compute_kid()
+            if cur_kid is not None:
+                if self.best_kid is not None:
+                    self.best_kid = min(self.best_kid, cur_kid)
+                else:
+                    self.best_kid = cur_kid
+                metric_dict.update({'KID': cur_kid, 'best_KID': self.best_kid})
+        if self.use_fid:
+            cur_fid = self._compute_fid()
+            if cur_fid is not None:
+                if self.best_fid is not None:
+                    self.best_fid = min(self.best_fid, cur_fid)
+                else:
+                    self.best_fid = cur_fid
+                metric_dict.update({'FID': cur_fid, 'best_FID': self.best_fid})
+
+        if is_master():
+            self._write_to_meters(metric_dict, self.metric_meters)
+            self._flush_meters(self.metric_meters)
diff --git a/imaginaire/trainers/gancraft.py b/imaginaire/trainers/gancraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..167e26161493ea82901d9d8e14f030ad71444c0e
--- /dev/null
+++ b/imaginaire/trainers/gancraft.py
@@ -0,0 +1,327 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import collections
+import os
+
+import torch
+import torch.nn as nn
+
+from imaginaire.config import Config
+from imaginaire.generators.spade import Generator as SPADEGenerator
+from imaginaire.losses import (FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss)
+from imaginaire.model_utils.gancraft.loss import GANLoss
+from imaginaire.trainers.base import BaseTrainer
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.io import get_checkpoint
+from imaginaire.utils.misc import split_labels, to_device
+from imaginaire.utils.trainer import ModelAverage, WrappedModel
+from imaginaire.utils.visualization import tensor2label
+
+
+class GauGANLoader(object):
+    r"""Manages the SPADE/GauGAN model used to generate pseudo-GTs for training GANcraft.
+
+    Args:
+        gaugan_cfg (Config): SPADE configuration.
+    """
+
+    def __init__(self, gaugan_cfg):
+        print('[GauGANLoader] Loading GauGAN model.')
+        cfg = Config(gaugan_cfg.config)
+        default_checkpoint_path = os.path.basename(gaugan_cfg.config).split('.yaml')[0] + '-' + \
+            cfg.pretrained_weight + '.pt'
+        checkpoint = get_checkpoint(default_checkpoint_path, cfg.pretrained_weight)
+        ckpt = torch.load(checkpoint)
+
+        net_G = WrappedModel(ModelAverage(SPADEGenerator(cfg.gen, cfg.data).to('cuda')))
+        net_G.load_state_dict(ckpt['net_G'])
+        self.net_GG = net_G.module.averaged_model
+        self.net_GG.eval()
+        self.net_GG.half()
+        print('[GauGANLoader] GauGAN loading complete.')
+
+    def eval(self, label, z=None, style_img=None):
+        r"""Produce output given segmentation and other conditioning inputs.
+        A random style is used if neither z nor style_img is provided.
+
+        Args:
+            label (N x C x H x W tensor): One-hot segmentation mask.
+            z: Style vector.
+            style_img: Style image.
+        """
+        inputs = {'label': label[:, :-1].detach().half()}
+        random_style = True
+
+        if z is not None:
+            random_style = False
+            inputs['z'] = z.detach().half()
+        elif style_img is not None:
+            random_style = False
+            inputs['images'] = style_img.detach().half()
+
+        net_GG_output = self.net_GG(inputs, random_style=random_style)
+
+        return net_GG_output['fake_images']
+
+
+class Trainer(BaseTrainer):
+    r"""Initialize GANcraft trainer.
+
+    Args:
+        cfg (Config): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self,
+                 cfg,
+                 net_G,
+                 net_D,
+                 opt_G,
+                 opt_D,
+                 sch_G,
+                 sch_D,
+                 train_data_loader,
+                 val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+
+        # Load the pseudo-GT network only if in training mode, else not needed.
+        if not self.is_inference:
+            self.gaugan_model = GauGANLoader(cfg.trainer.gaugan_loader)
+
+    def _init_loss(self, cfg):
+        r"""Initialize loss terms.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        if hasattr(cfg.trainer.loss_weight, 'gan'):
+            self.criteria['GAN'] = GANLoss()
+            self.weights['GAN'] = cfg.trainer.loss_weight.gan
+        if hasattr(cfg.trainer.loss_weight, 'pseudo_gan'):
+            self.criteria['PGAN'] = GANLoss()
+            self.weights['PGAN'] = cfg.trainer.loss_weight.pseudo_gan
+        if hasattr(cfg.trainer.loss_weight, 'l2'):
+            self.criteria['L2'] = nn.MSELoss()
+            self.weights['L2'] = cfg.trainer.loss_weight.l2
+        if hasattr(cfg.trainer.loss_weight, 'l1'):
+            self.criteria['L1'] = nn.L1Loss()
+            self.weights['L1'] = cfg.trainer.loss_weight.l1
+        if hasattr(cfg.trainer, 'perceptual_loss'):
+            self.criteria['Perceptual'] = \
+                PerceptualLoss(
+                    network=cfg.trainer.perceptual_loss.mode,
+                    layers=cfg.trainer.perceptual_loss.layers,
+                    weights=cfg.trainer.perceptual_loss.weights)
+            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
+        # Setup the feature matching loss.
+        if hasattr(cfg.trainer.loss_weight, 'feature_matching'):
+            self.criteria['FeatureMatching'] = FeatureMatchingLoss()
+            self.weights['FeatureMatching'] = \
+                cfg.trainer.loss_weight.feature_matching
+        # Setup the Gaussian KL divergence loss.
+        if hasattr(cfg.trainer.loss_weight, 'kl'):
+            self.criteria['GaussianKL'] = GaussianKLLoss()
+            self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl
+
+    def _start_of_epoch(self, current_epoch):
+        torch.cuda.empty_cache()  # Prevent the first iteration from running OOM.
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Model specific custom start of iteration process. We will do two
+        things. First, put all the data to GPU. Second, we will resize the
+        input so that it becomes multiple of the factor for bug-free
+        convolutional operations. This factor is given by the yaml file.
+        E.g., base = getattr(self.net_G, 'base', 32)
+
+        Args:
+            data (dict): The current batch.
+            current_iteration (int): The iteration number of the current batch.
+        """
+        data = to_device(data, 'cuda')
+
+        # Sample camera poses and pseudo-GTs.
+        with torch.no_grad():
+            samples = self.net_G.module.sample_camera(data, self.gaugan_model.eval)
+
+        return {**data, **samples}
+
+    def gen_forward(self, data):
+        r"""Compute the loss for SPADE generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        net_G_output = self.net_G(data, random_style=False)
+
+        self._time_before_loss()
+
+        if 'GAN' in self.criteria or 'PGAN' in self.criteria:
+            incl_pseudo_real = False
+            if 'FeatureMatching' in self.criteria:
+                incl_pseudo_real = True
+            net_D_output = self.net_D(data, net_G_output, incl_real=False, incl_pseudo_real=incl_pseudo_real)
+            output_fake = net_D_output['fake_outputs']  # Choose from real_outputs and fake_outputs.
+
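+            # The same generator-side adversarial term drives both GAN and
+            # PGAN; they differ only in the discriminator update, which uses
+            # real vs. pseudo-real reference images respectively.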
+            gan_loss = self.criteria['GAN'](output_fake, True, dis_update=False)
+            if 'GAN' in self.criteria:
+                self.gen_losses['GAN'] = gan_loss
+            if 'PGAN' in self.criteria:
+                self.gen_losses['PGAN'] = gan_loss
+
+        if 'FeatureMatching' in self.criteria:
+            self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
+                net_D_output['fake_features'], net_D_output['pseudo_real_features'])
+
+        if 'GaussianKL' in self.criteria:
+            self.gen_losses['GaussianKL'] = self.criteria['GaussianKL'](net_G_output['mu'], net_G_output['logvar'])
+
+        # Perceptual loss is always between fake image and pseudo real image.
+        if 'Perceptual' in self.criteria:
+            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
+                net_G_output['fake_images'], data['pseudo_real_img'])
+
+        # Reconstruction loss between fake and pseudo real.
+        if 'L2' in self.criteria:
+            self.gen_losses['L2'] = self.criteria['L2'](net_G_output['fake_images'], data['pseudo_real_img'])
+        if 'L1' in self.criteria:
+            self.gen_losses['L1'] = self.criteria['L1'](net_G_output['fake_images'], data['pseudo_real_img'])
+
+        total_loss = 0
+        for key in self.criteria:
+            total_loss = total_loss + self.gen_losses[key] * self.weights[key]
+
+        self.gen_losses['total'] = total_loss
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for GANcraft discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        if 'GAN' not in self.criteria and 'PGAN' not in self.criteria:
+            return
+
+        with torch.no_grad():
+            net_G_output = self.net_G(data, random_style=False)
+            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
+
+        incl_real = False
+        incl_pseudo_real = False
+        if 'GAN' in self.criteria:
+            incl_real = True
+        if 'PGAN' in self.criteria:
+            incl_pseudo_real = True
+        net_D_output = self.net_D(data, net_G_output, incl_real=incl_real, incl_pseudo_real=incl_pseudo_real)
+
+        self._time_before_loss()
+        total_loss = 0
+        if 'GAN' in self.criteria:
+            output_fake = net_D_output['fake_outputs']
+            output_real = net_D_output['real_outputs']
+
+            fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
+            true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
+            self.dis_losses['GAN/fake'] = fake_loss
+            self.dis_losses['GAN/true'] = true_loss
+            self.dis_losses['GAN'] = fake_loss + true_loss
+            total_loss = total_loss + self.dis_losses['GAN'] * self.weights['GAN']
+        if 'PGAN' in self.criteria:
+            output_fake = net_D_output['fake_outputs']
+            output_pseudo_real = net_D_output['pseudo_real_outputs']
+
+            fake_loss = self.criteria['PGAN'](output_fake, False, dis_update=True)
+            true_loss = self.criteria['PGAN'](output_pseudo_real, True, dis_update=True)
+            self.dis_losses['PGAN/fake'] = fake_loss
+            self.dis_losses['PGAN/true'] = true_loss
+            self.dis_losses['PGAN'] = fake_loss + true_loss
+            total_loss = total_loss + self.dis_losses['PGAN'] * self.weights['PGAN']
+
+        self.dis_losses['total'] = total_loss
+        return total_loss
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization image.
+
+        Args:
+            data (dict): The current batch.
+        """
+        with torch.no_grad():
+            label_lengths = self.train_data_loader.dataset.get_label_lengths()
+            labels = split_labels(data['label'], label_lengths)
+
+            # Get visualization of the real image and segmentation mask.
+            segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True)
+            segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
+
+            # Get output from GANcraft model
+            net_G_output_randstyle = self.net_G(data, random_style=True)
+            net_G_output = self.net_G(data, random_style=False)
+
+            vis_images = [data['images'], segmap, net_G_output_randstyle['fake_images'], net_G_output['fake_images']]
+
+            if 'fake_masks' in data:
+                # Get pseudo-GT.
+                labels = split_labels(data['fake_masks'], label_lengths)
+                segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True)
+                segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
+                vis_images.append(segmap)
+
+            if 'pseudo_real_img' in data:
+                vis_images.append(data['pseudo_real_img'])
+
+            if self.cfg.trainer.model_average_config.enabled:
+                net_G_model_average_output = self.net_G.module.averaged_model(data, random_style=True)
+                vis_images.append(net_G_model_average_output['fake_images'])
+        return vis_images
+
+    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
+        r"""Load network weights, optimizer parameters, scheduler parameters
+        from a checkpoint.
+
+        Args:
+            cfg (obj): Global configuration.
+            checkpoint_path (str): Path to the checkpoint.
+            resume (bool or None): If not ``None``, determines whether or not
+                to load the optimizers in addition to the network weights.
+            load_sch (bool): Whether to also load the scheduler state.
+        """
+        ret = super().load_checkpoint(cfg, checkpoint_path, resume, load_sch)
+
+        if getattr(cfg.trainer, 'reset_opt_g_on_resume', False):
+            self.opt_G.state = collections.defaultdict(dict)
+            print('[GANcraft::load_checkpoint] Resetting opt_G.state')
+        if getattr(cfg.trainer, 'reset_opt_d_on_resume', False):
+            self.opt_D.state = collections.defaultdict(dict)
+            print('[GANcraft::load_checkpoint] Resetting opt_D.state')
+
+        return ret
+
+    def test(self, data_loader, output_dir, inference_args):
+        r"""Compute result images for a batch of input data and save them in
+        the specified folder.
+
+        Args:
+            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
+            output_dir (str): Target location for saving the output images.
+            inference_args (obj): Inference arguments forwarded as keyword
+                arguments to the generator's ``inference`` method.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G = self.net_G.module.averaged_model
+        else:
+            net_G = self.net_G.module
+        net_G.eval()
+
+        torch.cuda.empty_cache()
+        with torch.no_grad():
+            net_G.inference(output_dir, **vars(inference_args))
diff --git a/imaginaire/trainers/munit.py b/imaginaire/trainers/munit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0bc4b42e1d2b82ac4702cb8f99808603818e3d9
--- /dev/null
+++ b/imaginaire/trainers/munit.py
@@ -0,0 +1,312 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+from imaginaire.evaluation import compute_fid
+from imaginaire.losses import (GANLoss, GaussianKLLoss,
+                               PerceptualLoss)
+from imaginaire.trainers.base import BaseTrainer
+from imaginaire.utils.misc import random_shift
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.diff_aug import apply_diff_aug
+
+
+class Trainer(BaseTrainer):
+    r"""Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732)
+    algorithm.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                         train_data_loader, val_data_loader)
+        self.gan_recon = getattr(cfg.trainer, 'gan_recon', False)
+        self.best_fid_a = None
+        self.best_fid_b = None
+
+    def _init_loss(self, cfg):
+        r"""Initialize loss terms. In MUNIT, we have several loss terms,
+        including the GAN loss, the image reconstruction loss, the content
+        reconstruction loss, the style reconstruction loss, and the cycle
+        reconstruction loss. We also have an optional perceptual loss. A user
+        can choose to add gradient penalty or consistency regularization too.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
+        self.criteria['kl'] = GaussianKLLoss()
+        self.criteria['image_recon'] = torch.nn.L1Loss()
+        if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0:
+            self.criteria['perceptual'] = \
+                PerceptualLoss(network=cfg.trainer.perceptual_mode,
+                               layers=cfg.trainer.perceptual_layers)
+
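+        # Only loss terms with a positive weight in cfg.trainer.loss_weight
+        # are activated; the forward passes look these weights up by name,
+        # e.g. 'gan', 'image_recon', 'perceptual', 'style_recon',
+        # 'content_recon', 'cycle_recon', 'kl' and 'consistency_reg'.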
+        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
+            if loss_weight > 0:
+                self.weights[loss_name] = loss_weight
+
+    def gen_forward(self, data):
+        r"""Compute the loss for MUNIT generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        cycle_recon = 'cycle_recon' in self.weights
+        image_recon = 'image_recon' in self.weights
+        perceptual = 'perceptual' in self.weights
+        within_latent_recon = 'style_recon_within' in self.weights or \
+                              'content_recon_within' in self.weights
+
+        net_G_output = self.net_G(data,
+                                  image_recon=image_recon,
+                                  cycle_recon=cycle_recon,
+                                  within_latent_recon=within_latent_recon)
+
+        # Differentiable augmentation.
+        keys = ['images_ab', 'images_ba']
+        if self.gan_recon:
+            keys += ['images_aa', 'images_bb']
+        net_D_output = self.net_D(data,
+                                  apply_diff_aug(
+                                      net_G_output, keys, self.aug_policy),
+                                  real=False,
+                                  gan_recon=self.gan_recon)
+
+        self._time_before_loss()
+
+        # GAN loss
+        if self.gan_recon:
+            self.gen_losses['gan_a'] = \
+                0.5 * (self.criteria['gan'](net_D_output['out_ba'],
+                                            True, dis_update=False) +
+                       self.criteria['gan'](net_D_output['out_aa'],
+                                            True, dis_update=False))
+            self.gen_losses['gan_b'] = \
+                0.5 * (self.criteria['gan'](net_D_output['out_ab'],
+                                            True, dis_update=False) +
+                       self.criteria['gan'](net_D_output['out_bb'],
+                                            True, dis_update=False))
+        else:
+            self.gen_losses['gan_a'] = self.criteria['gan'](
+                net_D_output['out_ba'], True, dis_update=False)
+            self.gen_losses['gan_b'] = self.criteria['gan'](
+                net_D_output['out_ab'], True, dis_update=False)
+        self.gen_losses['gan'] = \
+            self.gen_losses['gan_a'] + self.gen_losses['gan_b']
+
+        # Perceptual loss
+        if perceptual:
+            self.gen_losses['perceptual_a'] = \
+                self.criteria['perceptual'](net_G_output['images_ab'],
+                                            data['images_a'])
+            self.gen_losses['perceptual_b'] = \
+                self.criteria['perceptual'](net_G_output['images_ba'],
+                                            data['images_b'])
+            self.gen_losses['perceptual'] = \
+                self.gen_losses['perceptual_a'] + \
+                self.gen_losses['perceptual_b']
+
+        # Image reconstruction loss
+        if image_recon:
+            self.gen_losses['image_recon'] = \
+                self.criteria['image_recon'](net_G_output['images_aa'],
+                                             data['images_a']) + \
+                self.criteria['image_recon'](net_G_output['images_bb'],
+                                             data['images_b'])
+
+        # Style reconstruction loss
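+        # (the style code re-encoded from a translated image should match the
+        # random style code that was injected to generate it)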
+        self.gen_losses['style_recon_a'] = torch.abs(
+            net_G_output['style_ba'] -
+            net_G_output['style_a_rand']).mean()
+        self.gen_losses['style_recon_b'] = torch.abs(
+            net_G_output['style_ab'] -
+            net_G_output['style_b_rand']).mean()
+        self.gen_losses['style_recon'] = \
+            self.gen_losses['style_recon_a'] + self.gen_losses['style_recon_b']
+
+        if within_latent_recon:
+            self.gen_losses['style_recon_aa'] = torch.abs(
+                net_G_output['style_aa'] -
+                net_G_output['style_a'].detach()).mean()
+            self.gen_losses['style_recon_bb'] = torch.abs(
+                net_G_output['style_bb'] -
+                net_G_output['style_b'].detach()).mean()
+            self.gen_losses['style_recon_within'] = \
+                self.gen_losses['style_recon_aa'] + \
+                self.gen_losses['style_recon_bb']
+
+        # Content reconstruction loss
+        self.gen_losses['content_recon_a'] = torch.abs(
+            net_G_output['content_ab'] -
+            net_G_output['content_a'].detach()).mean()
+        self.gen_losses['content_recon_b'] = torch.abs(
+            net_G_output['content_ba'] -
+            net_G_output['content_b'].detach()).mean()
+        self.gen_losses['content_recon'] = \
+            self.gen_losses['content_recon_a'] + \
+            self.gen_losses['content_recon_b']
+
+        if within_latent_recon:
+            self.gen_losses['content_recon_aa'] = torch.abs(
+                net_G_output['content_aa'] -
+                net_G_output['content_a'].detach()).mean()
+            self.gen_losses['content_recon_bb'] = torch.abs(
+                net_G_output['content_bb'] -
+                net_G_output['content_b'].detach()).mean()
+            self.gen_losses['content_recon_within'] = \
+                self.gen_losses['content_recon_aa'] + \
+                self.gen_losses['content_recon_bb']
+
+        # KL loss
+        self.gen_losses['kl'] = \
+            self.criteria['kl'](net_G_output['style_a']) + \
+            self.criteria['kl'](net_G_output['style_b'])
+
+        # Cycle reconstruction loss
+        if cycle_recon:
+            self.gen_losses['cycle_recon'] = \
+                torch.abs(net_G_output['images_aba'] -
+                          data['images_a']).mean() + \
+                torch.abs(net_G_output['images_bab'] -
+                          data['images_b']).mean()
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=True)
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for MUNIT discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        with torch.no_grad():
+            net_G_output = self.net_G(data,
+                                      image_recon=self.gan_recon,
+                                      latent_recon=False,
+                                      cycle_recon=False,
+                                      within_latent_recon=False)
+        net_G_output['images_ba'].requires_grad = True
+        net_G_output['images_ab'].requires_grad = True
+
+        # Differentiable augmentation.
+        keys_fake = ['images_ab', 'images_ba']
+        if self.gan_recon:
+            keys_fake += ['images_aa', 'images_bb']
+        keys_real = ['images_a', 'images_b']
+
+        net_D_output = self.net_D(
+            apply_diff_aug(data, keys_real, self.aug_policy),
+            apply_diff_aug(net_G_output, keys_fake, self.aug_policy),
+            gan_recon=self.gan_recon)
+
+        self._time_before_loss()
+
+        # GAN loss.
+        self.dis_losses['gan_a'] = \
+            self.criteria['gan'](net_D_output['out_a'], True) + \
+            self.criteria['gan'](net_D_output['out_ba'], False)
+        self.dis_losses['gan_b'] = \
+            self.criteria['gan'](net_D_output['out_b'], True) + \
+            self.criteria['gan'](net_D_output['out_ab'], False)
+        self.dis_losses['gan'] = \
+            self.dis_losses['gan_a'] + self.dis_losses['gan_b']
+
+        # Consistency regularization.
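+        # The discriminator features of flipped-and-randomly-shifted copies
+        # of the real and translated images are encouraged to match the
+        # features of the originals.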
+        self.dis_losses['consistency_reg'] = \
+            torch.tensor(0., device=torch.device('cuda'))
+        if 'consistency_reg' in self.weights:
+            data_aug, net_G_output_aug = {}, {}
+            data_aug['images_a'] = random_shift(data['images_a'].flip(-1))
+            data_aug['images_b'] = random_shift(data['images_b'].flip(-1))
+            net_G_output_aug['images_ab'] = \
+                random_shift(net_G_output['images_ab'].flip(-1))
+            net_G_output_aug['images_ba'] = \
+                random_shift(net_G_output['images_ba'].flip(-1))
+            net_D_output_aug = self.net_D(data_aug, net_G_output_aug)
+            feature_names = ['fea_ba', 'fea_ab',
+                             'fea_a', 'fea_b']
+            for feature_name in feature_names:
+                self.dis_losses['consistency_reg'] += \
+                    torch.pow(net_D_output_aug[feature_name] -
+                              net_D_output[feature_name], 2).mean()
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=False)
+        return total_loss
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization image.
+
+        Args:
+            data (dict): The current batch.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+        with torch.no_grad():
+            net_G_output = net_G_for_evaluation(data, random_style=False)
+            net_G_output_random = net_G_for_evaluation(data)
+            vis_images = [data['images_a'],
+                          data['images_b'],
+                          net_G_output['images_aa'],
+                          net_G_output['images_bb'],
+                          net_G_output['images_ab'],
+                          net_G_output_random['images_ab'],
+                          net_G_output['images_ba'],
+                          net_G_output_random['images_ba'],
+                          net_G_output['images_aba'],
+                          net_G_output['images_bab']]
+            return vis_images
+
+    def write_metrics(self):
+        r"""Compute metrics and save them to tensorboard"""
+        cur_fid_a, cur_fid_b = self._compute_fid()
+        if self.best_fid_a is not None:
+            self.best_fid_a = min(self.best_fid_a, cur_fid_a)
+        else:
+            self.best_fid_a = cur_fid_a
+        if self.best_fid_b is not None:
+            self.best_fid_b = min(self.best_fid_b, cur_fid_b)
+        else:
+            self.best_fid_b = cur_fid_b
+        self._write_to_meters({'FID_a': cur_fid_a,
+                               'best_FID_a': self.best_fid_a,
+                               'FID_b': cur_fid_b,
+                               'best_FID_b': self.best_fid_b},
+                              self.metric_meters)
+        self._flush_meters(self.metric_meters)
+
+    def _compute_fid(self):
+        r"""Compute FID for both domains.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+        fid_a_path = self._get_save_path('fid_a', 'npy')
+        fid_b_path = self._get_save_path('fid_b', 'npy')
+        fid_value_a = compute_fid(fid_a_path, self.val_data_loader,
+                                  net_G_for_evaluation, 'images_a', 'images_ba')
+        fid_value_b = compute_fid(fid_b_path, self.val_data_loader,
+                                  net_G_for_evaluation, 'images_b', 'images_ab')
+        print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
+            self.current_epoch, self.current_iteration,
+            fid_value_a, fid_value_b))
+        return fid_value_a, fid_value_b
diff --git a/imaginaire/trainers/pix2pixHD.py b/imaginaire/trainers/pix2pixHD.py
new file mode 100644
index 0000000000000000000000000000000000000000..af8ed264adb0f5e2be61ce258ea3221b3b365f6a
--- /dev/null
+++ b/imaginaire/trainers/pix2pixHD.py
@@ -0,0 +1,202 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import torch
+
+from imaginaire.evaluation import compute_fid
+from imaginaire.losses import FeatureMatchingLoss, GANLoss, PerceptualLoss
+from imaginaire.model_utils.pix2pixHD import cluster_features, get_edges
+from imaginaire.trainers.spade import Trainer as SPADETrainer
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.misc import to_cuda
+
+
+class Trainer(SPADETrainer):
+    r"""Initialize pix2pixHD trainer.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+
+    def _assign_criteria(self, name, criterion, weight):
+        r"""Assign training loss terms.
+
+        Args:
+            name (str): Loss name
+            criterion (obj): Loss object.
+            weight (float): Loss weight. It should be non-negative.
+        """
+        self.criteria[name] = criterion
+        self.weights[name] = weight
+
+    def _init_loss(self, cfg):
+        r"""Initialize training loss terms. In pix2pixHD, there are three
+        loss terms: GAN loss, feature matching loss, and perceptual loss.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria = dict()
+        self.weights = dict()
+        trainer_cfg = cfg.trainer
+        loss_weight = cfg.trainer.loss_weight
+        # GAN loss and feature matching loss.
+        self._assign_criteria('GAN',
+                              GANLoss(trainer_cfg.gan_mode),
+                              loss_weight.gan)
+        self._assign_criteria('FeatureMatching',
+                              FeatureMatchingLoss(),
+                              loss_weight.feature_matching)
+        self._assign_criteria('Perceptual',
+                              PerceptualLoss(
+                                  network=cfg.trainer.perceptual_loss.mode,
+                                  layers=cfg.trainer.perceptual_loss.layers,
+                                  weights=cfg.trainer.perceptual_loss.weights),
+                              loss_weight.perceptual)
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Things to do before an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current number of iteration.
+        """
+        return self.pre_process(data)
+
+    def gen_forward(self, data):
+        r"""Compute the loss for pix2pixHD generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        net_G_output = self.net_G(data)
+        net_D_output = self.net_D(data, net_G_output)
+
+        self._time_before_loss()
+
+        output_fake = self._get_outputs(net_D_output, real=False)
+        self.gen_losses['GAN'] = \
+            self.criteria['GAN'](output_fake, True, dis_update=False)
+
+        self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
+            net_D_output['fake_features'], net_D_output['real_features'])
+
+        if hasattr(self.cfg.trainer, 'perceptual_loss'):
+            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
+                net_G_output['fake_images'], data['images'])
+
+        total_loss = self.gen_losses['GAN'].new_tensor([0])
+        for key in self.criteria:
+            total_loss += self.gen_losses[key] * self.weights[key]
+
+        self.gen_losses['total'] = total_loss
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for pix2pixHD discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        with torch.no_grad():
+            net_G_output = self.net_G(data)
+            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
+        net_D_output = self.net_D(data, net_G_output)
+
+        self._time_before_loss()
+
+        output_fake = self._get_outputs(net_D_output, real=False)
+        output_real = self._get_outputs(net_D_output, real=True)
+        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
+        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
+        self.dis_losses['GAN'] = fake_loss + true_loss
+        total_loss = self.dis_losses['GAN'] * self.weights['GAN']
+        self.dis_losses['total'] = total_loss
+        return total_loss
+
+    def pre_process(self, data):
+        r"""Data pre-processing step for the pix2pixHD method. The input
+        dictionary contains a label field, which is the concatenation of the
+        segmentation mask and the instance map. This function replaces the
+        instance map with an edge map and adds an "instance_maps" field that
+        keeps a copy of the original instance map.
+
+        Args:
+            data (dict): Input dictionary.
+            data['label']: Input label map where the last channel is the
+                instance map.
+        """
+        data = to_cuda(data)
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G = self.net_G.module.module
+        else:
+            net_G = self.net_G.module
+        if net_G.contain_instance_map:
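+            # The last label channel carries the instance map; keep a copy in
+            # 'instance_maps' and replace that channel with instance boundary
+            # edges, which is what the networks consume.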
+            inst_maps = data['label'][:, -1:]
+            edge_maps = get_edges(inst_maps)
+            data['instance_maps'] = inst_maps.clone()
+            data['label'][:, -1:] = edge_maps
+        return data
+
+    def _pre_save_checkpoint(self):
+        r"""Things to do before saving the checkpoints. For pix2pixHD, we
+        compute the K-means features before saving the model weights to the
+        checkpoint.
+        """
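+        # pix2pixHD's instance-wise feature encoder relies on K-means cluster
+        # centers of the encoded features; they are recomputed here on the
+        # validation loader before the weights are written out.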
+        if hasattr(self.cfg.gen, 'enc'):
+            if self.cfg.trainer.model_average_config.enabled:
+                net_E = self.net_G.module.averaged_model.encoder
+            else:
+                net_E = self.net_G.module.encoder
+            is_cityscapes = getattr(self.cfg.gen, 'is_cityscapes', False)
+            cluster_features(self.cfg, self.val_data_loader,
+                             net_E,
+                             self.pre_process,
+                             is_cityscapes)
+
+    def _compute_fid(self):
+        r"""Compute FID. Both the regular model and the moving average model
+        (if enabled) are evaluated in eval mode.
+        """
+        self.net_G.eval()
+        net_G_for_evaluation = \
+            functools.partial(self.net_G, random_style=True)
+        regular_fid_path = self._get_save_path('regular_fid', 'npy')
+        regular_fid_value = compute_fid(regular_fid_path,
+                                        self.val_data_loader,
+                                        net_G_for_evaluation,
+                                        preprocess=self.pre_process)
+        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
+            self.current_epoch, self.current_iteration, regular_fid_value))
+        if self.cfg.trainer.model_average_config.enabled:
+            avg_net_G_for_evaluation = \
+                functools.partial(self.net_G.module.averaged_model,
+                                  random_style=True)
+            fid_path = self._get_save_path('average_fid', 'npy')
+            fid_value = compute_fid(fid_path, self.val_data_loader,
+                                    avg_net_G_for_evaluation,
+                                    preprocess=self.pre_process)
+            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
+                self.current_epoch, self.current_iteration, fid_value))
+            self.net_G.float()
+            return regular_fid_value, fid_value
+        else:
+            self.net_G.float()
+            return regular_fid_value
diff --git a/imaginaire/trainers/spade.py b/imaginaire/trainers/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6132f20d019c18d3bd8670f00ed8b5a7a389e85
--- /dev/null
+++ b/imaginaire/trainers/spade.py
@@ -0,0 +1,282 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import math
+
+import torch
+import torch.nn.functional as F
+
+from imaginaire.evaluation import compute_fid
+from imaginaire.losses import (FeatureMatchingLoss, GANLoss, GaussianKLLoss,
+                               PerceptualLoss)
+from imaginaire.trainers.base import BaseTrainer
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.model_average import reset_batch_norm, \
+    calibrate_batch_norm_momentum
+from imaginaire.utils.misc import split_labels, to_device
+from imaginaire.utils.visualization import tensor2label
+
+
+class Trainer(BaseTrainer):
+    r"""Initialize SPADE trainer.
+
+    Args:
+        cfg (Config): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self,
+                 cfg,
+                 net_G,
+                 net_D,
+                 opt_G,
+                 opt_D,
+                 sch_G,
+                 sch_D,
+                 train_data_loader,
+                 val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+        if cfg.data.type == 'imaginaire.datasets.paired_videos':
+            self.video_mode = True
+        else:
+            self.video_mode = False
+
+    def _init_loss(self, cfg):
+        r"""Initialize loss terms.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria['GAN'] = GANLoss(cfg.trainer.gan_mode)
+        self.weights['GAN'] = cfg.trainer.loss_weight.gan
+        # Setup the perceptual loss. Note that the perceptual loss can run in
+        # fp16 mode for additional speed. We find that running in fp16 mode
+        # improves training speed while maintaining the same accuracy.
+        if hasattr(cfg.trainer, 'perceptual_loss'):
+            self.criteria['Perceptual'] = \
+                PerceptualLoss(
+                    network=cfg.trainer.perceptual_loss.mode,
+                    layers=cfg.trainer.perceptual_loss.layers,
+                    weights=cfg.trainer.perceptual_loss.weights)
+            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
+        # Setup the feature matching loss.
+        self.criteria['FeatureMatching'] = FeatureMatchingLoss()
+        self.weights['FeatureMatching'] = \
+            cfg.trainer.loss_weight.feature_matching
+        # Setup the Gaussian KL divergence loss.
+        self.criteria['GaussianKL'] = GaussianKLLoss()
+        self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Model specific custom start-of-iteration process. We do two
+        things. First, put all the data on the GPU. Second, resize the input
+        so that its spatial size becomes a multiple of the factor required by
+        the convolutional layers. This factor is given by the yaml file,
+        e.g. ``base = getattr(self.net_G, 'base', 32)``.
+
+        Args:
+            data (dict): The current batch.
+            current_iteration (int): The iteration number of the current batch.
+        """
+        data = to_device(data, 'cuda')
+        data = self._resize_data(data)
+        return data
+
+    def gen_forward(self, data):
+        r"""Compute the loss for SPADE generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        net_G_output = self.net_G(data)
+        net_D_output = self.net_D(data, net_G_output)
+
+        self._time_before_loss()
+
+        output_fake = self._get_outputs(net_D_output, real=False)
+        self.gen_losses['GAN'] = self.criteria['GAN'](output_fake, True, dis_update=False)
+
+        self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
+            net_D_output['fake_features'], net_D_output['real_features'])
+
+        if self.net_G_module.use_style_encoder:
+            self.gen_losses['GaussianKL'] = \
+                self.criteria['GaussianKL'](net_G_output['mu'],
+                                            net_G_output['logvar'])
+        else:
+            self.gen_losses['GaussianKL'] = \
+                self.gen_losses['GAN'].new_tensor([0])
+
+        if hasattr(self.cfg.trainer, 'perceptual_loss'):
+            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
+                net_G_output['fake_images'], data['images'])
+
+        total_loss = self.gen_losses['GAN'].new_tensor([0])
+        for key in self.criteria:
+            total_loss += self.gen_losses[key] * self.weights[key]
+
+        self.gen_losses['total'] = total_loss
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for SPADE discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        with torch.no_grad():
+            net_G_output = self.net_G(data)
+            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
+        net_D_output = self.net_D(data, net_G_output)
+
+        self._time_before_loss()
+
+        output_fake = self._get_outputs(net_D_output, real=False)
+        output_real = self._get_outputs(net_D_output, real=True)
+        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
+        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
+        self.dis_losses['GAN/fake'] = fake_loss
+        self.dis_losses['GAN/true'] = true_loss
+        self.dis_losses['GAN'] = fake_loss + true_loss
+        total_loss = self.dis_losses['GAN'] * self.weights['GAN']
+        self.dis_losses['total'] = total_loss
+        return total_loss
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization images. We first recalculate the batch norm
+        statistics for the moving average model.
+
+        Args:
+            data (dict): The current batch.
+        """
+        self.recalculate_batch_norm_statistics(
+            self.train_data_loader)
+        with torch.no_grad():
+            label_lengths = self.train_data_loader.dataset.get_label_lengths()
+            labels = split_labels(data['label'], label_lengths)
+            # Get visualization of the segmentation mask.
+            vis_images = list()
+            vis_images.append(data['images'])
+            net_G_output = self.net_G(data, random_style=True)
+            for key in labels.keys():
+                if 'seg' in key:
+                    segmaps = tensor2label(labels[key], label_lengths[key], output_normalized_tensor=True)
+                    segmaps = torch.cat([x.unsqueeze(0) for x in segmaps], 0)
+                    vis_images.append(segmaps)
+                if 'edge' in key:
+                    edgemaps = torch.cat((labels[key], labels[key], labels[key]), 1)
+                    vis_images.append(edgemaps)
+
+            vis_images.append(net_G_output['fake_images'])
+            if self.cfg.trainer.model_average_config.enabled:
+                net_G_model_average_output = \
+                    self.net_G.module.averaged_model(data, random_style=True)
+                vis_images.append(net_G_model_average_output['fake_images'])
+        return vis_images
+
+    def recalculate_batch_norm_statistics(self, data_loader):
+        r"""Update the statistics in the moving average model.
+
+        Args:
+            data_loader (pytorch data loader): Data loader for estimating the
+                statistics.
+        """
+        if not self.cfg.trainer.model_average_config.enabled:
+            return
+        model_average_iteration = \
+            self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations
+        if model_average_iteration == 0:
+            return
+        with torch.no_grad():
+            # Accumulate BN stats.
+            self.net_G.module.averaged_model.train()
+            # Reset running stats.
+            self.net_G.module.averaged_model.apply(reset_batch_norm)
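+            # calibrate_batch_norm_momentum is re-applied for every
+            # calibration batch; presumably it adjusts each BN layer's
+            # momentum so the running statistics average over the whole
+            # calibration pass (assumption based on the helper's name).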
+            for cal_it, cal_data in enumerate(data_loader):
+                if cal_it >= model_average_iteration:
+                    print('Done with {} iterations of updating batch norm '
+                          'statistics'.format(model_average_iteration))
+                    break
+                # cal_data = to_device(cal_data, 'cuda')
+                cal_data = self._start_of_iteration(cal_data, 0)
+                # Averaging over all batches
+                self.net_G.module.averaged_model.apply(
+                    calibrate_batch_norm_momentum)
+                self.net_G.module.averaged_model(cal_data)
+
+    def write_metrics(self):
+        r"""If a moving average model is present, we report two meters, one
+        for the regular FID and one for the average FID. If there is no
+        moving average model, we only report the regular FID.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            regular_fid, average_fid = self._compute_fid()
+            metric_dict = {'FID/average': average_fid, 'FID/regular': regular_fid}
+            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
+        else:
+            regular_fid = self._compute_fid()
+            metric_dict = {'FID/regular': regular_fid}
+            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
+        self._flush_meters(self.metric_meters)
+
+    def _compute_fid(self):
+        r"""Compute FID. Both the regular model and the moving average model
+        (if enabled) are evaluated in eval mode.
+        """
+        self.net_G.eval()
+        net_G_for_evaluation = \
+            functools.partial(self.net_G, random_style=True)
+        regular_fid_path = self._get_save_path('regular_fid', 'npy')
+        preprocess = \
+            functools.partial(self._start_of_iteration, current_iteration=0)
+
+        regular_fid_value = compute_fid(regular_fid_path,
+                                        self.val_data_loader,
+                                        net_G_for_evaluation,
+                                        preprocess=preprocess)
+        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
+            self.current_epoch, self.current_iteration, regular_fid_value))
+        if self.cfg.trainer.model_average_config.enabled:
+            avg_net_G_for_evaluation = \
+                functools.partial(self.net_G.module.averaged_model,
+                                  random_style=True)
+            fid_path = self._get_save_path('average_fid', 'npy')
+            fid_value = compute_fid(fid_path, self.val_data_loader,
+                                    avg_net_G_for_evaluation,
+                                    preprocess=preprocess)
+            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
+                self.current_epoch, self.current_iteration, fid_value))
+            self.net_G.float()
+            return regular_fid_value, fid_value
+        else:
+            self.net_G.float()
+            return regular_fid_value
+
+    def _resize_data(self, data):
+        r"""Resize the input label maps and images so that they can be
+        properly processed by the generator.
+
+        Args:
+            data (dict): Input dictionary containing the 'label' and 'images'
+                fields.
+        """
+        base = getattr(self.net_G, 'base', 32)
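+        # Round the spatial size down to the nearest multiple of `base`;
+        # e.g. with base=32, a 1000x750 label map is resized to 992x736.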
+        sy = math.floor(data['label'].size()[2] * 1.0 // base) * base
+        sx = math.floor(data['label'].size()[3] * 1.0 // base) * base
+        data['label'] = F.interpolate(
+            data['label'], size=[sy, sx], mode='nearest')
+        if 'images' in data.keys():
+            data['images'] = F.interpolate(
+                data['images'], size=[sy, sx], mode='bicubic')
+        return data
diff --git a/imaginaire/trainers/unit.py b/imaginaire/trainers/unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b3a4e0926acd2affcd96c6a9a09463c5056951c
--- /dev/null
+++ b/imaginaire/trainers/unit.py
@@ -0,0 +1,210 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+
+from imaginaire.evaluation import compute_fid
+from imaginaire.losses import GANLoss, PerceptualLoss  # GaussianKLLoss
+from imaginaire.trainers.base import BaseTrainer
+
+
+class Trainer(BaseTrainer):
+    r"""Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848)
+    algorithm.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                         train_data_loader, val_data_loader)
+        self.best_fid_a = None
+        self.best_fid_b = None
+
+    def _init_loss(self, cfg):
+        r"""Initialize loss terms. In UNIT, we have several loss terms,
+        including the GAN loss, the image reconstruction loss, the cycle
+        reconstruction loss, and the Gaussian KL loss. We also have an
+        optional perceptual loss. A user can choose to add the gradient
+        penalty loss too.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
+        # self.criteria['gaussian_kl'] = GaussianKLLoss()
+        self.criteria['image_recon'] = nn.L1Loss()
+        self.criteria['cycle_recon'] = nn.L1Loss()
+        if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0:
+            self.criteria['perceptual'] = \
+                PerceptualLoss(network=cfg.trainer.perceptual_mode,
+                               layers=cfg.trainer.perceptual_layers)
+
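+        # As in MUNIT, only loss terms with a positive weight in
+        # cfg.trainer.loss_weight are activated and looked up by name,
+        # e.g. 'gan', 'image_recon', 'cycle_recon' and 'perceptual'.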
+        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
+            if loss_weight > 0:
+                self.weights[loss_name] = loss_weight
+
+    def gen_forward(self, data):
+        r"""Compute the loss for UNIT generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        cycle_recon = 'cycle_recon' in self.weights
+        perceptual = 'perceptual' in self.weights
+        net_G_output = self.net_G(data, cycle_recon=cycle_recon)
+        net_D_output = self.net_D(data, net_G_output, real=False)
+
+        self._time_before_loss()
+
+        # GAN loss
+        self.gen_losses['gan_a'] = self.criteria['gan'](
+            net_D_output['out_ba'], True, dis_update=False)
+        self.gen_losses['gan_b'] = self.criteria['gan'](
+            net_D_output['out_ab'], True, dis_update=False)
+        self.gen_losses['gan'] = \
+            self.gen_losses['gan_a'] + self.gen_losses['gan_b']
+
+        # Perceptual loss
+        if perceptual:
+            self.gen_losses['perceptual_a'] = \
+                self.criteria['perceptual'](net_G_output['images_ab'],
+                                            data['images_a'])
+            self.gen_losses['perceptual_b'] = \
+                self.criteria['perceptual'](net_G_output['images_ba'],
+                                            data['images_b'])
+            self.gen_losses['perceptual'] = \
+                self.gen_losses['perceptual_a'] + \
+                self.gen_losses['perceptual_b']
+
+        # Image reconstruction loss
+        self.gen_losses['image_recon'] = \
+            self.criteria['image_recon'](net_G_output['images_aa'],
+                                         data['images_a']) + \
+            self.criteria['image_recon'](net_G_output['images_bb'],
+                                         data['images_b'])
+
+        """
+        # KL loss
+        self.gen_losses['gaussian_kl'] = \
+            self.criteria['gaussian_kl'](net_G_output['content_mu_a']) + \
+            self.criteria['gaussian_kl'](net_G_output['content_mu_b']) + \
+            self.criteria['gaussian_kl'](net_G_output['content_mu_a_recon']) + \
+            self.criteria['gaussian_kl'](net_G_output['content_mu_b_recon'])
+        """
+
+        # Cycle reconstruction loss
+        if cycle_recon:
+            self.gen_losses['cycle_recon_aba'] = \
+                self.criteria['cycle_recon'](net_G_output['images_aba'],
+                                             data['images_a'])
+            self.gen_losses['cycle_recon_bab'] = \
+                self.criteria['cycle_recon'](net_G_output['images_bab'],
+                                             data['images_b'])
+            self.gen_losses['cycle_recon'] = \
+                self.gen_losses['cycle_recon_aba'] + \
+                self.gen_losses['cycle_recon_bab']
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=True)
+        return total_loss
+
+    def dis_forward(self, data):
+        r"""Compute the loss for UNIT discriminator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        with torch.no_grad():
+            net_G_output = self.net_G(data, image_recon=False,
+                                      cycle_recon=False)
+        net_G_output['images_ba'].requires_grad = True
+        net_G_output['images_ab'].requires_grad = True
+        net_D_output = self.net_D(data, net_G_output)
+
+        self._time_before_loss()
+
+        # GAN loss.
+        self.dis_losses['gan_a'] = \
+            self.criteria['gan'](net_D_output['out_a'], True) + \
+            self.criteria['gan'](net_D_output['out_ba'], False)
+        self.dis_losses['gan_b'] = \
+            self.criteria['gan'](net_D_output['out_b'], True) + \
+            self.criteria['gan'](net_D_output['out_ab'], False)
+        self.dis_losses['gan'] = \
+            self.dis_losses['gan_a'] + self.dis_losses['gan_b']
+
+        # Compute total loss
+        total_loss = self._get_total_loss(gen_forward=False)
+        return total_loss
+
+    def _get_visualizations(self, data):
+        r"""Compute visualization image.
+
+        Args:
+            data (dict): The current batch.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+        with torch.no_grad():
+            net_G_output = net_G_for_evaluation(data)
+            vis_images = [data['images_a'],
+                          data['images_b'],
+                          net_G_output['images_aa'],
+                          net_G_output['images_bb'],
+                          net_G_output['images_ab'],
+                          net_G_output['images_ba'],
+                          net_G_output['images_aba'],
+                          net_G_output['images_bab']]
+            return vis_images
+
+    def write_metrics(self):
+        r"""Compute metrics and save them to tensorboard"""
+        cur_fid_a, cur_fid_b = self._compute_fid()
+        if self.best_fid_a is not None:
+            self.best_fid_a = min(self.best_fid_a, cur_fid_a)
+        else:
+            self.best_fid_a = cur_fid_a
+        if self.best_fid_b is not None:
+            self.best_fid_b = min(self.best_fid_b, cur_fid_b)
+        else:
+            self.best_fid_b = cur_fid_b
+        self._write_to_meters({'FID_a': cur_fid_a,
+                               'best_FID_a': self.best_fid_a,
+                               'FID_b': cur_fid_b,
+                               'best_FID_b': self.best_fid_b},
+                              self.metric_meters)
+        self._flush_meters(self.metric_meters)
+
+    def _compute_fid(self):
+        r"""Compute FID for both domains.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            net_G_for_evaluation = self.net_G.module.averaged_model
+        else:
+            net_G_for_evaluation = self.net_G
+        fid_a_path = self._get_save_path('fid_a', 'npy')
+        fid_b_path = self._get_save_path('fid_b', 'npy')
+        fid_value_a = compute_fid(fid_a_path, self.val_data_loader,
+                                  net_G_for_evaluation, 'images_a', 'images_ba')
+        fid_value_b = compute_fid(fid_b_path, self.val_data_loader,
+                                  net_G_for_evaluation, 'images_b', 'images_ab')
+        print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
+            self.current_epoch, self.current_iteration,
+            fid_value_a, fid_value_b))
+        return fid_value_a, fid_value_b
diff --git a/imaginaire/trainers/vid2vid.py b/imaginaire/trainers/vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd6b46fa64f6738c21636f785470e25120324b57
--- /dev/null
+++ b/imaginaire/trainers/vid2vid.py
@@ -0,0 +1,913 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+from torch.cuda.amp import autocast
+import imageio
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from imaginaire.evaluation.fid import compute_fid
+from imaginaire.losses import (FeatureMatchingLoss, FlowLoss, GANLoss,
+                               PerceptualLoss)
+from imaginaire.model_utils.fs_vid2vid import (concat_frames, detach,
+                                               get_fg_mask,
+                                               pre_process_densepose, resample)
+from imaginaire.trainers.base import BaseTrainer
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.misc import get_nested_attr, split_labels, to_cuda
+from imaginaire.utils.visualization import (tensor2flow, tensor2im, tensor2label)
+from imaginaire.utils.visualization.pose import tensor2pose
+
+
+class Trainer(BaseTrainer):
+    r"""Initialize vid2vid trainer.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+        # The settings below are for testing. The FID computed during
+        # training only gives a quick idea of the performance; it is not the
+        # final performance evaluation. Here we determine how many videos to
+        # evaluate and the number of frames per video. It is better to keep
+        # the number of videos a multiple of 8 so that all the GPUs in a node
+        # contribute equally to the evaluation and none of them sits idle.
+        self.sample_size = (
+            getattr(cfg.trainer, 'num_videos_to_test', 64),
+            getattr(cfg.trainer, 'num_frames_per_video', 10)
+        )
+
+        self.sequence_length = 1
+        if not self.is_inference:
+            self.train_dataset = self.train_data_loader.dataset
+            self.sequence_length_max = \
+                min(getattr(cfg.data.train, 'max_sequence_length', 100),
+                    self.train_dataset.sequence_length_max)
+        self.Tensor = torch.cuda.FloatTensor
+        self.has_fg = getattr(cfg.data, 'has_foreground', False)
+
+        self.net_G_output = self.data_prev = None
+        self.net_G_module = self.net_G.module
+        if self.cfg.trainer.model_average_config.enabled:
+            self.net_G_module = self.net_G_module.module
+
+    def _assign_criteria(self, name, criterion, weight):
+        r"""Assign training loss terms.
+
+        Args:
+            name (str): Loss name
+            criterion (obj): Loss object.
+            weight (float): Loss weight. It should be non-negative.
+        """
+        self.criteria[name] = criterion
+        self.weights[name] = weight
+
+    def _init_loss(self, cfg):
+        r"""Initialize training loss terms. In vid2vid, in addition to
+        the GAN loss, feature matching loss, and perceptual loss used in
+        pix2pixHD, we also add temporal GAN (and feature matching) loss,
+        and flow warping loss. Optionally, we can also add an additional
+        face discriminator for the face region.
+
+        Args:
+            cfg (obj): Global configuration.
+        """
+        self.criteria = dict()
+        self.weights = dict()
+        trainer_cfg = cfg.trainer
+        loss_weight = cfg.trainer.loss_weight
+
+        # GAN loss and feature matching loss.
+        self._assign_criteria('GAN',
+                              GANLoss(trainer_cfg.gan_mode),
+                              loss_weight.gan)
+        self._assign_criteria('FeatureMatching',
+                              FeatureMatchingLoss(),
+                              loss_weight.feature_matching)
+
+        # Perceptual loss.
+        perceptual_loss = cfg.trainer.perceptual_loss
+        self._assign_criteria('Perceptual',
+                              PerceptualLoss(
+                                  network=perceptual_loss.mode,
+                                  layers=perceptual_loss.layers,
+                                  weights=perceptual_loss.weights,
+                                  num_scales=getattr(perceptual_loss,
+                                                     'num_scales', 1)),
+                              loss_weight.perceptual)
+
+        # L1 Loss.
+        if getattr(loss_weight, 'L1', 0) > 0:
+            self._assign_criteria('L1', torch.nn.L1Loss(), loss_weight.L1)
+
+        # Whether to add an additional discriminator for specific regions.
+        self.add_dis_cfg = getattr(self.cfg.dis, 'additional_discriminators',
+                                   None)
+        if self.add_dis_cfg is not None:
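+            # e.g. a face discriminator; each additional discriminator `name`
+            # gets its own 'GAN_<name>' and 'FeatureMatching_<name>' weight.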
+            for name in self.add_dis_cfg:
+                add_dis_cfg = self.add_dis_cfg[name]
+                self.weights['GAN_' + name] = add_dis_cfg.loss_weight
+                self.weights['FeatureMatching_' + name] = \
+                    loss_weight.feature_matching
+
+        # Temporal GAN loss.
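+        # Each temporal scale s contributes its own adversarial and
+        # feature-matching terms, weighted via the 'GAN_T%d' and
+        # 'FeatureMatching_T%d' keys set below.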
+        self.num_temporal_scales = get_nested_attr(self.cfg.dis,
+                                                   'temporal.num_scales', 0)
+        for s in range(self.num_temporal_scales):
+            self.weights['GAN_T%d' % s] = loss_weight.temporal_gan
+            self.weights['FeatureMatching_T%d' % s] = \
+                loss_weight.feature_matching
+
+        # Flow loss. It consists of three parts: an L1 loss against the GT
+        # flow, a warping loss when the flow is used to warp images, and a
+        # loss on the occlusion mask.
+        self.use_flow = hasattr(cfg.gen, 'flow')
+        if self.use_flow:
+            self.criteria['Flow'] = FlowLoss(cfg)
+            self.weights['Flow'] = self.weights['Flow_L1'] = \
+                self.weights['Flow_Warp'] = \
+                self.weights['Flow_Mask'] = loss_weight.flow
+
+        # Other custom losses.
+        self._define_custom_losses()
+
+    def _define_custom_losses(self):
+        r"""All other custom losses are defined here."""
+        pass
+
+    def _start_of_epoch(self, current_epoch):
+        r"""Things to do before an epoch. When current_epoch is smaller than
+        $(single_frame_epoch), we only train a single frame and the generator is
+        just an image generator. After that, we start doing temporal training
+        and train multiple frames. We will double the number of training frames
+        every $(num_epochs_temporal_step) epochs.
+
+        Args:
+            current_epoch (int): Current number of epoch.
+        """
+        cfg = self.cfg
+        # Only generates one frame at the beginning of training
+        if current_epoch < cfg.single_frame_epoch:
+            self.train_dataset.sequence_length = 1
+        # Then add the temporal network to generator, and train multiple frames.
+        elif current_epoch == cfg.single_frame_epoch:
+            self.init_temporal_network()
+
+        # Double the length of training sequence every few epochs.
+        temp_epoch = current_epoch - cfg.single_frame_epoch
+        if temp_epoch > 0:
+            sequence_length = \
+                cfg.data.train.initial_sequence_length * \
+                (2 ** (temp_epoch // cfg.num_epochs_temporal_step))
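+            # Illustrative schedule (hypothetical values): with
+            # initial_sequence_length=4 and num_epochs_temporal_step=5, the
+            # length grows 4 -> 8 -> 16 -> ... every 5 temporal epochs, and is
+            # capped at sequence_length_max by the min() below.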
+            sequence_length = min(sequence_length, self.sequence_length_max)
+            if sequence_length > self.sequence_length:
+                self.sequence_length = sequence_length
+                self.train_dataset.set_sequence_length(sequence_length)
+                print('------- Updating sequence length to %d -------' %
+                      sequence_length)
+
+    def init_temporal_network(self):
+        r"""Initialize temporal training when beginning to train multiple
+        frames. Set the sequence length to $(initial_sequence_length).
+        """
+        self.tensorboard_init = False
+        # Update training sequence length.
+        self.sequence_length = self.cfg.data.train.initial_sequence_length
+        if not self.is_inference:
+            self.train_dataset.set_sequence_length(self.sequence_length)
+            print('------ Now start training %d frames -------' %
+                  self.sequence_length)
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Things to do before an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current number of iteration.
+        """
+        data = self.pre_process(data)
+        return to_cuda(data)
+
+    def pre_process(self, data):
+        r"""Do any data pre-processing here.
+
+        Args:
+            data (dict): Data used for the current iteration.
+        """
+        data_cfg = self.cfg.data
+        if hasattr(data_cfg, 'for_pose_dataset') and \
+                ('pose_maps-densepose' in data_cfg.input_labels):
+            pose_cfg = data_cfg.for_pose_dataset
+            data['label'] = pre_process_densepose(pose_cfg, data['label'],
+                                                  self.is_inference)
+        return data
+
+    def post_process(self, data, net_G_output):
+        r"""Do any postprocessing of the data / output here.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            net_G_output (dict): Output of the generator.
+        """
+        return data, net_G_output
+
+    def gen_update(self, data):
+        r"""Update the vid2vid generator. We update in the fashion of
+        dis_update (frame 1), gen_update (frame 1),
+        dis_update (frame 2), gen_update (frame 2), ... in each iteration.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        # Whether to reuse generator output for both gen_update and
+        # dis_update. It saves time but consumes a bit more memory.
+        reuse_gen_output = getattr(self.cfg.trainer, 'reuse_gen_output', True)
+
+        past_frames = [None, None]
+        net_G_output = None
+        data_prev = None
+        for t in range(self.sequence_length):
+            data_t = self.get_data_t(data, net_G_output, data_prev, t)
+            data_prev = data_t
+
+            # Discriminator update.
+            if reuse_gen_output:
+                net_G_output = self.net_G(data_t)
+            else:
+                with torch.no_grad():
+                    net_G_output = self.net_G(data_t)
+            data_t, net_G_output = self.post_process(data_t, net_G_output)
+
+            # Compute losses and update D only if the image was generated by
+            # the network during training (not by a pretrained model).
+            if 'fake_images_source' not in net_G_output:
+                net_G_output['fake_images_source'] = 'in_training'
+            if net_G_output['fake_images_source'] != 'pretrained':
+                net_D_output, _ = self.net_D(data_t, detach(net_G_output), past_frames)
+                self.get_dis_losses(net_D_output)
+
+            # Generator update.
+            if not reuse_gen_output:
+                net_G_output = self.net_G(data_t)
+                data_t, net_G_output = self.post_process(data_t, net_G_output)
+
+            # Compute losses and update G only if the image was generated by
+            # the network during training (not by a pretrained model).
+            if 'fake_images_source' not in net_G_output:
+                net_G_output['fake_images_source'] = 'in_training'
+            if net_G_output['fake_images_source'] != 'pretrained':
+                net_D_output, past_frames = \
+                    self.net_D(data_t, net_G_output, past_frames)
+                self.get_gen_losses(data_t, net_G_output, net_D_output)
+
+            # Update the generator's model average.
+            if self.cfg.trainer.model_average_config.enabled:
+                self.net_G.module.update_average()
+
+    def dis_update(self, data):
+        r"""The update is already done in gen_update.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        pass
+
+    def reset(self):
+        r"""Reset the trainer (for inference) at the beginning of a sequence.
+        """
+        # print('Resetting trainer.')
+        self.net_G_output = self.data_prev = None
+        self.t = 0
+
+        self.test_in_model_average_mode = getattr(
+            self, 'test_in_model_average_mode', self.cfg.trainer.model_average_config.enabled)
+        if self.test_in_model_average_mode:
+            net_G_module = self.net_G.module.averaged_model
+        else:
+            net_G_module = self.net_G.module
+        if hasattr(net_G_module, 'reset'):
+            net_G_module.reset()
+
+    def create_sequence_output_dir(self, output_dir, key):
+        r"""Create output subdir for this sequence.
+
+        Args:
+            output_dir (str): Root output dir.
+            key (str): LMDB key which contains sequence name and file name.
+        Returns:
+            output_dir (str): Output subdir for this sequence.
+            seq_name (str): Name of this sequence.
+        """
+        seq_dir = '/'.join(key.split('/')[:-1])
+        output_dir = os.path.join(output_dir, seq_dir)
+        os.makedirs(output_dir, exist_ok=True)
+        seq_name = seq_dir.replace('/', '-')
+        return output_dir, seq_name
+
+    def test(self, test_data_loader, root_output_dir, inference_args):
+        r"""Run inference on all sequences.
+
+        Args:
+            test_data_loader (object): Test data loader.
+            root_output_dir (str): Location to dump outputs.
+            inference_args (optional): Optional args.
+        """
+
+        # Go over all sequences.
+        loader = test_data_loader
+        num_inference_sequences = loader.dataset.num_inference_sequences()
+        for sequence_idx in range(num_inference_sequences):
+            loader.dataset.set_inference_sequence_idx(sequence_idx)
+            print('Seq id: %d, Seq length: %d' %
+                  (sequence_idx + 1, len(loader)))
+
+            # Reset model at start of new inference sequence.
+            self.reset()
+            self.sequence_length = len(loader)
+
+            # Go over all frames of this sequence.
+            video = []
+            for idx, data in enumerate(tqdm(loader)):
+                key = data['key']['images'][0][0]
+                filename = key.split('/')[-1]
+
+                # Create output dir for this sequence.
+                if idx == 0:
+                    output_dir, seq_name = \
+                        self.create_sequence_output_dir(root_output_dir, key)
+                    video_path = os.path.join(output_dir, '..', seq_name)
+
+                # Get output and save images.
+                data['img_name'] = filename
+                data = self.start_of_iteration(data, current_iteration=-1)
+                output = self.test_single(data, output_dir, inference_args)
+                video.append(output)
+
+            # Save output as mp4.
+            imageio.mimsave(video_path + '.mp4', video, fps=15)
+
+    def test_single(self, data, output_dir=None, inference_args=None):
+        r"""The inference function. If output_dir exists, also save the
+        output image.
+        Args:
+            data (dict): Training data at the current iteration.
+            output_dir (str): Save image directory.
+            inference_args (obj): Inference args.
+        """
+        if getattr(inference_args, 'finetune', False):
+            if not getattr(self, 'has_finetuned', False):
+                self.finetune(data, inference_args)
+
+        net_G = self.net_G
+        if self.test_in_model_average_mode:
+            net_G = net_G.module.averaged_model
+        net_G.eval()
+
+        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
+        if self.is_inference or self.sequence_length > 1:
+            self.data_prev = data_t
+
+        # Generator forward.
+        with torch.no_grad():
+            self.net_G_output = net_G(data_t)
+
+        if output_dir is None:
+            return self.net_G_output
+
+        save_fake_only = getattr(inference_args, 'save_fake_only', False)
+        if save_fake_only:
+            image_grid = tensor2im(self.net_G_output['fake_images'])[0]
+        else:
+            vis_images = self.get_test_output_images(data)
+            image_grid = np.hstack([np.vstack(im) for im in
+                                    vis_images if im is not None])
+        if 'img_name' in data:
+            save_name = data['img_name'].split('.')[0] + '.jpg'
+        else:
+            save_name = '%04d.jpg' % self.t
+        output_filename = os.path.join(output_dir, save_name)
+        os.makedirs(output_dir, exist_ok=True)
+        imageio.imwrite(output_filename, image_grid)
+        self.t += 1
+
+        return image_grid
+
+    def get_test_output_images(self, data):
+        r"""Get the visualization output of test function.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        vis_images = [
+            self.visualize_label(data['label'][:, -1]),
+            tensor2im(data['images'][:, -1]),
+            tensor2im(self.net_G_output['fake_images']),
+        ]
+        return vis_images
+
+    def gen_frames(self, data, use_model_average=False):
+        r"""Generate a sequence of frames given a sequence of data.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            use_model_average (bool): Whether to use model average
+                for update or not.
+        """
+        net_G_output = None  # Previous generator output.
+        data_prev = None  # Previous data.
+        if use_model_average:
+            net_G = self.net_G.module.averaged_model
+        else:
+            net_G = self.net_G
+
+        # Iterate through the length of sequence.
+        all_info = {'inputs': [], 'outputs': []}
+        for t in range(self.sequence_length):
+            # Get the data at the current time frame.
+            data_t = self.get_data_t(data, net_G_output, data_prev, t)
+            data_prev = data_t
+
+            # Generator forward.
+            with torch.no_grad():
+                net_G_output = net_G(data_t)
+
+            # Do any postprocessing if necessary.
+            data_t, net_G_output = self.post_process(data_t, net_G_output)
+
+            if t == 0:
+                # Get the output at beginning of sequence for visualization.
+                first_net_G_output = net_G_output
+
+            all_info['inputs'].append(data_t)
+            all_info['outputs'].append(net_G_output)
+
+        return first_net_G_output, net_G_output, all_info
+
+    def get_gen_losses(self, data_t, net_G_output, net_D_output):
+        r"""Compute generator losses.
+
+        Args:
+            data_t (dict): Training data at the current time t.
+            net_G_output (dict): Output of the generator.
+            net_D_output (dict): Output of the discriminator.
+        """
+        update_finished = False
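+        # The loop below retries the full loss computation whenever the AMP
+        # GradScaler skips the optimizer step because of a gradient overflow.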
+        while not update_finished:
+            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
+                # Individual frame GAN loss and feature matching loss.
+                self.gen_losses['GAN'], self.gen_losses['FeatureMatching'] = \
+                    self.compute_gan_losses(net_D_output['indv'],
+                                            dis_update=False)
+
+                # Perceptual loss.
+                self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
+                    net_G_output['fake_images'], data_t['image'])
+
+                # L1 loss.
+                if getattr(self.cfg.trainer.loss_weight, 'L1', 0) > 0:
+                    self.gen_losses['L1'] = self.criteria['L1'](
+                        net_G_output['fake_images'], data_t['image'])
+
+                # Raw (hallucinated) output image losses (GAN and perceptual).
+                if 'raw' in net_D_output:
+                    raw_GAN_losses = self.compute_gan_losses(
+                        net_D_output['raw'], dis_update=False
+                    )
+                    fg_mask = get_fg_mask(data_t['label'], self.has_fg)
+                    raw_perceptual_loss = self.criteria['Perceptual'](
+                        net_G_output['fake_raw_images'] * fg_mask,
+                        data_t['image'] * fg_mask)
+                    self.gen_losses['GAN'] += raw_GAN_losses[0]
+                    self.gen_losses['FeatureMatching'] += raw_GAN_losses[1]
+                    self.gen_losses['Perceptual'] += raw_perceptual_loss
+
+                # Additional discriminator losses.
+                if self.add_dis_cfg is not None:
+                    for name in self.add_dis_cfg:
+                        (self.gen_losses['GAN_' + name],
+                         self.gen_losses['FeatureMatching_' + name]) = \
+                            self.compute_gan_losses(net_D_output[name],
+                                                    dis_update=False)
+
+                # Flow and mask loss.
+                if self.use_flow:
+                    (self.gen_losses['Flow_L1'], self.gen_losses['Flow_Warp'],
+                     self.gen_losses['Flow_Mask']) = self.criteria['Flow'](
+                        data_t, net_G_output, self.current_epoch)
+
+                # Temporal GAN loss and feature matching loss.
+                if self.cfg.trainer.loss_weight.temporal_gan > 0:
+                    if self.sequence_length > 1:
+                        for s in range(self.num_temporal_scales):
+                            loss_GAN, loss_FM = self.compute_gan_losses(
+                                net_D_output['temporal_%d' % s],
+                                dis_update=False
+                            )
+                            self.gen_losses['GAN_T%d' % s] = loss_GAN
+                            self.gen_losses['FeatureMatching_T%d' % s] = loss_FM
+
+                # Other custom losses.
+                self._get_custom_gen_losses(data_t, net_G_output, net_D_output)
+
+                # Sum all losses together.
+                total_loss = self.Tensor(1).fill_(0)
+                for key in self.gen_losses:
+                    if key != 'total':
+                        total_loss += self.gen_losses[key] * self.weights[key]
+                self.gen_losses['total'] = total_loss
+
+            # Zero-grad and backpropagate the loss.
+            self.opt_G.zero_grad(set_to_none=True)
+            self.scaler_G.scale(total_loss).backward()
+
+            # Optionally clip gradient norm.
+            if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
+                self.scaler_G.unscale_(self.opt_G)
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    self.net_G_module.parameters(),
+                    self.cfg.gen_opt.clip_grad_norm
+                )
+                if torch.isfinite(total_norm) and \
+                        total_norm > self.cfg.gen_opt.clip_grad_norm:
+                    print(f"Gradient norm of the generator ({total_norm}) "
+                          f"too large, clipping it to "
+                          f"{self.cfg.gen_opt.clip_grad_norm}.")
+
+            # Perform an optimizer step.
+            self.scaler_G.step(self.opt_G)
+            self.scaler_G.update()
+            # Whether the step above was skipped.
+            if self.last_step_count_G == self.opt_G._step_count:
+                print("Generator overflowed!")
+            else:
+                self.last_step_count_G = self.opt_G._step_count
+                update_finished = True
+
+    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
+        r"""All other custom generator losses go here.
+
+        Args:
+            data_t (dict): Training data at the current time t.
+            net_G_output (dict): Output of the generator.
+            net_D_output (dict): Output of the discriminator.
+        """
+        pass
+
+    def get_dis_losses(self, net_D_output):
+        r"""Compute discriminator losses.
+
+        Args:
+            net_D_output (dict): Output of the discriminator.
+        """
+        update_finished = False
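+        # Same overflow-retry pattern as in get_gen_losses above.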
+        while not update_finished:
+            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
+                # Individual frame GAN loss.
+                self.dis_losses['GAN'] = self.compute_gan_losses(
+                    net_D_output['indv'], dis_update=True
+                )
+
+                # Raw (hallucinated) output image GAN loss.
+                if 'raw' in net_D_output:
+                    raw_loss = self.compute_gan_losses(net_D_output['raw'],
+                                                       dis_update=True)
+                    self.dis_losses['GAN'] += raw_loss
+
+                # Additional GAN loss.
+                if self.add_dis_cfg is not None:
+                    for name in self.add_dis_cfg:
+                        self.dis_losses['GAN_' + name] = \
+                            self.compute_gan_losses(net_D_output[name],
+                                                    dis_update=True)
+
+                # Temporal GAN loss.
+                if self.cfg.trainer.loss_weight.temporal_gan > 0:
+                    if self.sequence_length > 1:
+                        for s in range(self.num_temporal_scales):
+                            self.dis_losses['GAN_T%d' % s] = \
+                                self.compute_gan_losses(
+                                    net_D_output['temporal_%d' % s],
+                                    dis_update=True
+                                )
+
+                # Other custom losses.
+                self._get_custom_dis_losses(net_D_output)
+
+                # Sum all losses together.
+                total_loss = self.Tensor(1).fill_(0)
+                for key in self.dis_losses:
+                    if key != 'total':
+                        total_loss += self.dis_losses[key] * self.weights[key]
+                self.dis_losses['total'] = total_loss
+
+            # Zero-grad and backpropagate the loss.
+            self.opt_D.zero_grad(set_to_none=True)
+            self._time_before_backward()
+            self.scaler_D.scale(total_loss).backward()
+
+            # Optionally clip gradient norm.
+            if hasattr(self.cfg.dis_opt, 'clip_grad_norm'):
+                self.scaler_D.unscale_(self.opt_D)
+                total_norm = torch.nn.utils.clip_grad_norm_(
+                    self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm
+                )
+                if torch.isfinite(total_norm) and \
+                        total_norm > self.cfg.dis_opt.clip_grad_norm:
+                    print(f"Gradient norm of the discriminator ({total_norm}) "
+                          f"too large, clipping it to "
+                          f"{self.cfg.dis_opt.clip_grad_norm}.")
+
+            # Perform an optimizer step.
+            self._time_before_step()
+            self.scaler_D.step(self.opt_D)
+            self.scaler_D.update()
+            # Whether the step above was skipped.
+            if self.last_step_count_D == self.opt_D._step_count:
+                print("Discriminator overflowed!")
+            else:
+                self.last_step_count_D = self.opt_D._step_count
+                update_finished = True
+
+    def _get_custom_dis_losses(self, net_D_output):
+        r"""All other custom losses go here.
+
+        Args:
+            net_D_output (dict): Output of the discriminator.
+        """
+        pass
+
+    def compute_gan_losses(self, net_D_output, dis_update):
+        r"""Compute GAN loss and feature matching loss.
+
+        Args:
+            net_D_output (dict): Output of the discriminator.
+            dis_update (bool): Whether to update discriminator.
+        """
+        if net_D_output['pred_fake'] is None:
+            return self.Tensor(1).fill_(0) if dis_update else [
+                self.Tensor(1).fill_(0), self.Tensor(1).fill_(0)]
+        if dis_update:
+            # Get the GAN loss for real/fake outputs.
+            GAN_loss = \
+                self.criteria['GAN'](net_D_output['pred_fake']['output'], False,
+                                     dis_update=True) + \
+                self.criteria['GAN'](net_D_output['pred_real']['output'], True,
+                                     dis_update=True)
+            return GAN_loss
+        else:
+            # Get the GAN loss and feature matching loss for fake output.
+            GAN_loss = self.criteria['GAN'](
+                net_D_output['pred_fake']['output'], True, dis_update=False)
+
+            FM_loss = self.criteria['FeatureMatching'](
+                net_D_output['pred_fake']['features'],
+                net_D_output['pred_real']['features'])
+            return GAN_loss, FM_loss
+
+    def get_data_t(self, data, net_G_output, data_prev, t):
+        r"""Get data at current time frame given the sequence of data.
+
+        Args:
+            data (dict): Training data for current iteration.
+            net_G_output (dict): Output of the generator (for previous frame).
+            data_prev (dict): Data for previous frame.
+            t (int): Current time.
+        """
+        label = data['label'][:, t]
+        image = data['images'][:, t]
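+        # 'label' and 'images' are indexed as (batch, time, ...); [:, t]
+        # selects the frame for the current time step.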
+
+        if data_prev is not None:
+            # Concat previous labels/fake images to the ones before.
+            num_frames_G = self.cfg.data.num_frames_G
+            prev_labels = concat_frames(data_prev['prev_labels'],
+                                        data_prev['label'], num_frames_G - 1)
+            prev_images = concat_frames(
+                data_prev['prev_images'],
+                net_G_output['fake_images'].detach(), num_frames_G - 1)
+        else:
+            prev_labels = prev_images = None
+
+        data_t = dict()
+        data_t['label'] = label
+        data_t['image'] = image
+        data_t['prev_labels'] = prev_labels
+        data_t['prev_images'] = prev_images
+        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
+        return data_t
+
+    def _end_of_iteration(self, data, current_epoch, current_iteration):
+        r"""Print the errors to console."""
+        if not torch.distributed.is_initialized():
+            if current_iteration % self.cfg.logging_iter == 0:
+                message = '(epoch: %d, iters: %d) ' % (current_epoch,
+                                                       current_iteration)
+                for k, v in self.gen_losses.items():
+                    if k != 'total':
+                        message += '%s: %.3f,  ' % (k, v)
+                message += '\n'
+                for k, v in self.dis_losses.items():
+                    if k != 'total':
+                        message += '%s: %.3f,  ' % (k, v)
+                print(message)
+
+    def write_metrics(self):
+        r"""If moving average model presents, we have two meters one for
+        regular FID and one for average FID. If no moving average model,
+        we just report average FID.
+        """
+        if self.cfg.trainer.model_average_config.enabled:
+            regular_fid, average_fid = self._compute_fid()
+            if regular_fid is None or average_fid is None:
+                return
+            metric_dict = {'FID/average': average_fid, 'FID/regular': regular_fid}
+            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
+        else:
+            regular_fid = self._compute_fid()
+            if regular_fid is None:
+                return
+            metric_dict = {'FID/regular': regular_fid}
+            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
+        self._flush_meters(self.metric_meters)
+
+    def _compute_fid(self):
+        r"""Compute FID values."""
+        self.net_G.eval()
+        self.net_G_output = None
+        # Due to the complicated video evaluation procedure we use, we pass
+        # the trainer to the evaluation code instead of the generator network.
+        # net_G_for_evaluation = self.net_G
+        trainer = self
+        self.test_in_model_average_mode = False
+        regular_fid_path = self._get_save_path('regular_fid', 'npy')
+        few_shot = 'few_shot' in self.cfg.data.type
+        regular_fid_value = compute_fid(regular_fid_path, self.val_data_loader,
+                                        trainer,
+                                        sample_size=self.sample_size,
+                                        is_video=True, few_shot_video=few_shot)
+        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
+            self.current_epoch, self.current_iteration, regular_fid_value))
+        if self.cfg.trainer.model_average_config.enabled:
+            # Due to the complicated video evaluation procedure we use, we
+            # pass the trainer to the evaluation code instead of the
+            # generator network.
+            # avg_net_G_for_evaluation = self.net_G.module.averaged_model
+            trainer_avg_mode = self
+            self.test_in_model_average_mode = True
+            # The above flag will be reset after computing FID.
+            fid_path = self._get_save_path('average_fid', 'npy')
+            few_shot = 'few_shot' in self.cfg.data.type
+            fid_value = compute_fid(fid_path, self.val_data_loader,
+                                    trainer_avg_mode,
+                                    sample_size=self.sample_size,
+                                    is_video=True, few_shot_video=few_shot)
+            print('Epoch {:05}, Iteration {:09}, Average FID {}'.format(
+                self.current_epoch, self.current_iteration, fid_value))
+            self.net_G.float()
+            return regular_fid_value, fid_value
+        else:
+            self.net_G.float()
+            return regular_fid_value
+
+    def visualize_label(self, label):
+        r"""Visualize the input label when saving to image.
+
+        Args:
+            label (tensor): Input label tensor.
+        """
+        cfgdata = self.cfg.data
+        if hasattr(cfgdata, 'for_pose_dataset'):
+            label = tensor2pose(self.cfg, label)
+        elif hasattr(cfgdata, 'input_labels') and \
+                'seg_maps' in cfgdata.input_labels:
+            for input_type in cfgdata.input_types:
+                if 'seg_maps' in input_type:
+                    num_labels = cfgdata.one_hot_num_classes.seg_maps
+            label = tensor2label(label, num_labels)
+        elif getattr(cfgdata, 'label_channels', 1) > 3:
+            label = tensor2im(label.sum(1, keepdim=True))
+        else:
+            label = tensor2im(label)
+        return label
+
+    def save_image(self, path, data):
+        r"""Save the output images to path.
+        Note: when generate_raw_output is False,
+        first_net_G_output['fake_raw_images'] is None and will not be displayed.
+        In model average mode, the flow visualization is plotted twice.
+        Args:
+            path (str): Save path.
+            data (dict): Training data for current iteration.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            self.net_G.module.averaged_model.eval()
+        self.net_G_output = None
+        with torch.no_grad():
+            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
+            if self.cfg.trainer.model_average_config.enabled:
+                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
+                    data, use_model_average=True)
+
+        # Visualize labels.
+        label_lengths = self.train_data_loader.dataset.get_label_lengths()
+        labels = split_labels(data['label'], label_lengths)
+        vis_labels_start, vis_labels_end = [], []
+        for key, value in labels.items():
+            if key == 'seg_maps':
+                vis_labels_start.append(self.visualize_label(value[:, -1]))
+                vis_labels_end.append(self.visualize_label(value[:, 0]))
+            else:
+                vis_labels_start.append(tensor2im(value[:, -1]))
+                vis_labels_end.append(tensor2im(value[:, 0]))
+
+        if is_master():
+            vis_images = [
+                *vis_labels_start,
+                tensor2im(data['images'][:, -1]),
+                tensor2im(net_G_output['fake_images']),
+                tensor2im(net_G_output['fake_raw_images'])]
+            if self.cfg.trainer.model_average_config.enabled:
+                vis_images += [
+                    tensor2im(net_G_output_avg['fake_images']),
+                    tensor2im(net_G_output_avg['fake_raw_images'])]
+
+            if self.sequence_length > 1:
+                vis_images_first = [
+                    *vis_labels_end,
+                    tensor2im(data['images'][:, 0]),
+                    tensor2im(first_net_G_output['fake_images']),
+                    tensor2im(first_net_G_output['fake_raw_images'])
+                ]
+                if self.cfg.trainer.model_average_config.enabled:
+                    vis_images_first += [
+                        tensor2im(first_net_G_output_avg['fake_images']),
+                        tensor2im(first_net_G_output_avg['fake_raw_images'])]
+
+                if self.use_flow:
+                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
+                        data['images'][:, -1], data['images'][:, -2])
+                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
+                    vis_images_first += [
+                        tensor2flow(flow_gt),
+                        tensor2im(conf_gt, normalize=False),
+                        tensor2im(warped_image_gt),
+                    ]
+                    vis_images += [
+                        tensor2flow(net_G_output['fake_flow_maps']),
+                        tensor2im(net_G_output['fake_occlusion_masks'],
+                                  normalize=False),
+                        tensor2im(net_G_output['warped_images']),
+                    ]
+                    if self.cfg.trainer.model_average_config.enabled:
+                        vis_images_first += [
+                            tensor2flow(flow_gt),
+                            tensor2im(conf_gt, normalize=False),
+                            tensor2im(warped_image_gt),
+                        ]
+                        vis_images += [
+                            tensor2flow(net_G_output_avg['fake_flow_maps']),
+                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
+                                      normalize=False),
+                            tensor2im(net_G_output_avg['warped_images'])]
+
+                vis_images = [[np.vstack((im_first, im))
+                               for im_first, im in zip(imgs_first, imgs)]
+                              for imgs_first, imgs in zip(vis_images_first,
+                                                          vis_images)
+                              if imgs is not None]
+
+            image_grid = np.hstack([np.vstack(im) for im in
+                                    vis_images if im is not None])
+
+            print('Save output images to {}'.format(path))
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            imageio.imwrite(path, image_grid)
+
+            # Gather all outputs for dumping into video.
+            if self.sequence_length > 1:
+                output_images = []
+                for item in all_info['outputs']:
+                    output_images.append(tensor2im(item['fake_images'])[0])
+
+                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
+                                 output_images, fps=2, macro_block_size=None)
+
+        self.net_G.float()
diff --git a/imaginaire/trainers/wc_vid2vid.py b/imaginaire/trainers/wc_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b3007cd1767ab3df3b43524747d858b1d3c454
--- /dev/null
+++ b/imaginaire/trainers/wc_vid2vid.py
@@ -0,0 +1,503 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import time
+
+import imageio
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from imaginaire.losses import MaskedL1Loss
+from imaginaire.model_utils.fs_vid2vid import concat_frames, resample
+from imaginaire.trainers.vid2vid import Trainer as Vid2VidTrainer
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.misc import split_labels, to_cuda
+from imaginaire.utils.visualization import tensor2flow, tensor2im
+
+
+class Trainer(Vid2VidTrainer):
+    r"""Initialize world consistent vid2vid trainer.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network.
+        net_D (obj): Discriminator network.
+        opt_G (obj): Optimizer for the generator network.
+        opt_D (obj): Optimizer for the discriminator network.
+        sch_G (obj): Scheduler for the generator optimizer.
+        sch_D (obj): Scheduler for the discriminator optimizer.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+    """
+
+    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
+                 train_data_loader, val_data_loader):
+        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
+                                      opt_D, sch_G, sch_D,
+                                      train_data_loader, val_data_loader)
+        self.guidance_start_after = getattr(cfg.gen.guidance, 'start_from', 0)
+        self.train_data_loader = train_data_loader
+
+    def _define_custom_losses(self):
+        r"""All other custom losses are defined here."""
+        # Setup the guidance loss.
+        self.criteria['Guidance'] = MaskedL1Loss(normalize_over_valid=True)
+        self.weights['Guidance'] = self.cfg.trainer.loss_weight.guidance
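+        # The masked L1 term compares the generated frame with the guidance
+        # image only inside the guidance mask, normalizing over valid pixels.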
+
+    def start_of_iteration(self, data, current_iteration):
+        r"""Things to do before an iteration.
+
+        Args:
+            data (dict): Data used for the current iteration.
+            current_iteration (int): Current iteration number.
+        """
+        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])
+        # Keep unprojections on cpu to prevent unnecessary transfer.
+        unprojections = data.pop('unprojections')
+        data = to_cuda(data)
+        data['unprojections'] = unprojections
+
+        self.current_iteration = current_iteration
+        if not self.is_inference:
+            self.net_D.train()
+        self.net_G.train()
+        self.start_iteration_time = time.time()
+        return data
+
+    def reset(self):
+        r"""Reset the trainer (for inference) at the beginning of a sequence."""
+        # Inference time.
+        self.net_G_module.reset_renderer(is_flipped_input=False)
+
+        # print('Resetting trainer.')
+        self.net_G_output = self.data_prev = None
+        self.t = 0
+
+        test_in_model_average_mode = getattr(
+            self, 'test_in_model_average_mode', False)
+        if test_in_model_average_mode:
+            if hasattr(self.net_G.module.averaged_model, 'reset'):
+                self.net_G.module.averaged_model.reset()
+        else:
+            if hasattr(self.net_G.module, 'reset'):
+                self.net_G.module.reset()
+
+    def create_sequence_output_dir(self, output_dir, key):
+        r"""Create output subdir for this sequence.
+
+        Args:
+            output_dir (str): Root output dir.
+            key (str): LMDB key which contains sequence name and file name.
+        Returns:
+            output_dir (str): Output subdir for this sequence.
+            seq_name (str): Name of this sequence.
+        """
+        seq_dir = '/'.join(key.split('/')[:-1])
+        output_dir = os.path.join(output_dir, seq_dir)
+        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(output_dir + '/all', exist_ok=True)
+        os.makedirs(output_dir + '/fake', exist_ok=True)
+        seq_name = seq_dir.replace('/', '-')
+        return output_dir, seq_name
+
+    def test(self, test_data_loader, root_output_dir, inference_args):
+        r"""Run inference on all sequences.
+
+        Args:
+            test_data_loader (object): Test data loader.
+            root_output_dir (str): Location to dump outputs.
+            inference_args (optional): Optional args.
+        """
+
+        # Go over all sequences.
+        loader = test_data_loader
+        num_inference_sequences = loader.dataset.num_inference_sequences()
+        for sequence_idx in range(num_inference_sequences):
+            loader.dataset.set_inference_sequence_idx(sequence_idx)
+            print('Seq id: %d, Seq length: %d' %
+                  (sequence_idx + 1, len(loader)))
+
+            # Reset model at start of new inference sequence.
+            self.reset()
+            self.sequence_length = len(loader)
+
+            # Go over all frames of this sequence.
+            video = []
+            for idx, data in enumerate(tqdm(loader)):
+                key = data['key']['images'][0][0]
+                filename = key.split('/')[-1]
+
+                # Create output dir for this sequence.
+                if idx == 0:
+                    output_dir, seq_name = \
+                        self.create_sequence_output_dir(root_output_dir, key)
+                    video_path = os.path.join(output_dir, '..', seq_name)
+
+                # Get output, and save all vis to all/.
+                data['img_name'] = filename
+                data = to_cuda(data)
+                output = self.test_single(data, output_dir=output_dir + '/all')
+
+                # Dump just the fake image here.
+                fake = tensor2im(output['fake_images'])[0]
+                video.append(fake)
+                imageio.imsave(output_dir + '/fake/%s.jpg' % (filename), fake)
+
+            # Save output as mp4.
+            imageio.mimsave(video_path + '.mp4', video, fps=15)
+
+    def test_single(self, data, output_dir=None, save_fake_only=False):
+        r"""The inference function. If output_dir exists, also save the
+        output image.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            output_dir (str): Save image directory.
+            save_fake_only (bool): Only save the fake output image.
+        """
+        if self.is_inference and self.cfg.trainer.model_average_config.enabled:
+            test_in_model_average_mode = True
+        else:
+            test_in_model_average_mode = getattr(
+                self, 'test_in_model_average_mode', False)
+        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
+        if self.sequence_length > 1:
+            self.data_prev = data_t
+
+        # Generator forward.
+        # Reset renderer if first time step.
+        if self.t == 0:
+            self.net_G_module.reset_renderer(
+                is_flipped_input=data['is_flipped'])
+        with torch.no_grad():
+            if test_in_model_average_mode:
+                net_G = self.net_G.module.averaged_model
+            else:
+                net_G = self.net_G
+            self.net_G_output = net_G(data_t)
+
+        if output_dir is not None:
+            if save_fake_only:
+                image_grid = tensor2im(self.net_G_output['fake_images'])[0]
+            else:
+                vis_images = self.get_test_output_images(data)
+                image_grid = np.hstack([np.vstack(im) for im in
+                                        vis_images if im is not None])
+            if 'img_name' in data:
+                save_name = data['img_name'].split('.')[0] + '.jpg'
+            else:
+                save_name = '%04d.jpg' % self.t
+            output_filename = os.path.join(output_dir, save_name)
+            os.makedirs(output_dir, exist_ok=True)
+            imageio.imwrite(output_filename, image_grid)
+            self.t += 1
+
+        return self.net_G_output
+
+    def get_test_output_images(self, data):
+        r"""Get the visualization output of test function.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        # Visualize labels.
+        label_lengths = self.val_data_loader.dataset.get_label_lengths()
+        labels = split_labels(data['label'], label_lengths)
+        vis_labels = []
+        for key, value in labels.items():
+            if key == 'seg_maps':
+                vis_labels.append(self.visualize_label(value[:, -1]))
+            else:
+                vis_labels.append(tensor2im(value[:, -1]))
+
+        # Get gt image.
+        im = tensor2im(data['images'][:, -1])
+
+        # Get guidance image and masks.
+        if self.net_G_output['guidance_images_and_masks'] is not None:
+            guidance_image = tensor2im(
+                self.net_G_output['guidance_images_and_masks'][:, :3])
+            guidance_mask = tensor2im(
+                self.net_G_output['guidance_images_and_masks'][:, 3:4],
+                normalize=False)
+        else:
+            guidance_image = [np.zeros_like(item) for item in im]
+            guidance_mask = [np.zeros_like(item) for item in im]
+
+        # Create output.
+        vis_images = [
+            *vis_labels,
+            im,
+            guidance_image, guidance_mask,
+            tensor2im(self.net_G_output['fake_images']),
+        ]
+        return vis_images
+
+    def gen_frames(self, data, use_model_average=False):
+        r"""Generate a sequence of frames given a sequence of data.
+
+        Args:
+            data (dict): Training data at the current iteration.
+            use_model_average (bool): Whether to use model average
+                for update or not.
+        """
+        net_G_output = None  # Previous generator output.
+        data_prev = None  # Previous data.
+        if use_model_average:
+            net_G = self.net_G.module.averaged_model
+        else:
+            net_G = self.net_G
+
+        # Iterate through the length of sequence.
+        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])
+
+        all_info = {'inputs': [], 'outputs': []}
+        for t in range(self.sequence_length):
+            # Get the data at the current time frame.
+            data_t = self.get_data_t(data, net_G_output, data_prev, t)
+            data_prev = data_t
+
+            # Generator forward.
+            with torch.no_grad():
+                net_G_output = net_G(data_t)
+
+            # Do any postprocessing if necessary.
+            data_t, net_G_output = self.post_process(data_t, net_G_output)
+
+            if t == 0:
+                # Get the output at beginning of sequence for visualization.
+                first_net_G_output = net_G_output
+
+            all_info['inputs'].append(data_t)
+            all_info['outputs'].append(net_G_output)
+
+        return first_net_G_output, net_G_output, all_info
+
+    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
+        r"""All other custom generator losses go here.
+
+        Args:
+            data_t (dict): Training data at the current time t.
+            net_G_output (dict): Output of the generator.
+            net_D_output (dict): Output of the discriminator.
+        """
+        # Compute guidance loss.
+        if net_G_output['guidance_images_and_masks'] is not None:
+            guidance_image = net_G_output['guidance_images_and_masks'][:, :3]
+            guidance_mask = net_G_output['guidance_images_and_masks'][:, 3:]
+            self.gen_losses['Guidance'] = self.criteria['Guidance'](
+                net_G_output['fake_images'], guidance_image, guidance_mask)
+        else:
+            self.gen_losses['Guidance'] = self.Tensor(1).fill_(0)
+
+    def get_data_t(self, data, net_G_output, data_prev, t):
+        r"""Get data at current time frame given the sequence of data.
+
+        Args:
+            data (dict): Training data for current iteration.
+            net_G_output (dict): Output of the generator (for previous frame).
+            data_prev (dict): Data for previous frame.
+            t (int): Current time.
+        """
+        label = data['label'][:, t]
+        image = data['images'][:, t]
+
+        # Get keypoint mapping.
+        unprojection = None
+        if t >= self.guidance_start_after:
+            if 'unprojections' in data:
+                try:
+                    # Remove unwanted padding.
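+                    # The padded array records its valid length in the first
+                    # element of its last row, so the slice below keeps only
+                    # the real entries.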
+                    unprojection = {}
+                    for key, value in data['unprojections'].items():
+                        value = value[0, t].cpu().numpy()
+                        length = value[-1][0]
+                        unprojection[key] = value[:length]
+                except Exception:  # noqa
+                    pass
+
+        if data_prev is not None:
+            # Concat previous labels/fake images to the ones before.
+            num_frames_G = self.cfg.data.num_frames_G
+            prev_labels = concat_frames(data_prev['prev_labels'],
+                                        data_prev['label'], num_frames_G - 1)
+            prev_images = concat_frames(
+                data_prev['prev_images'],
+                net_G_output['fake_images'].detach(), num_frames_G - 1)
+        else:
+            prev_labels = prev_images = None
+
+        data_t = dict()
+        data_t['label'] = label
+        data_t['image'] = image
+        data_t['prev_labels'] = prev_labels
+        data_t['prev_images'] = prev_images
+        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
+        data_t['unprojection'] = unprojection
+        return data_t
+
+    def save_image(self, path, data):
+        r"""Save the output images to path.
+        Note: when generate_raw_output is False,
+        first_net_G_output['fake_raw_images'] is None and will not be displayed.
+        In model average mode, the flow visualization is plotted twice.
+
+        Args:
+            path (str): Save path.
+            data (dict): Training data for current iteration.
+        """
+        self.net_G.eval()
+        if self.cfg.trainer.model_average_config.enabled:
+            self.net_G.module.averaged_model.eval()
+        self.net_G_output = None
+        with torch.no_grad():
+            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
+            if self.cfg.trainer.model_average_config.enabled:
+                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
+                    data, use_model_average=True)
+
+        # Visualize labels.
+        label_lengths = self.train_data_loader.dataset.get_label_lengths()
+        labels = split_labels(data['label'], label_lengths)
+        vis_labels_start, vis_labels_end = [], []
+        for key, value in labels.items():
+            if 'seg_maps' in key:
+                vis_labels_start.append(self.visualize_label(value[:, -1]))
+                vis_labels_end.append(self.visualize_label(value[:, 0]))
+            else:
+                normalize = self.train_data_loader.dataset.normalize[key]
+                vis_labels_start.append(
+                    tensor2im(value[:, -1], normalize=normalize))
+                vis_labels_end.append(
+                    tensor2im(value[:, 0], normalize=normalize))
+
+        if is_master():
+            vis_images = [
+                *vis_labels_start,
+                tensor2im(data['images'][:, -1]),
+                tensor2im(net_G_output['fake_images']),
+                tensor2im(net_G_output['fake_raw_images'])]
+            if self.cfg.trainer.model_average_config.enabled:
+                vis_images += [
+                    tensor2im(net_G_output_avg['fake_images']),
+                    tensor2im(net_G_output_avg['fake_raw_images'])]
+
+            if self.sequence_length > 1:
+                if net_G_output['guidance_images_and_masks'] is not None:
+                    guidance_image = tensor2im(
+                        net_G_output['guidance_images_and_masks'][:, :3])
+                    guidance_mask = tensor2im(
+                        net_G_output['guidance_images_and_masks'][:, 3:4],
+                        normalize=False)
+                else:
+                    im = tensor2im(data['images'][:, -1])
+                    guidance_image = [np.zeros_like(item) for item in im]
+                    guidance_mask = [np.zeros_like(item) for item in im]
+                vis_images += [guidance_image, guidance_mask]
+
+                vis_images_first = [
+                    *vis_labels_end,
+                    tensor2im(data['images'][:, 0]),
+                    tensor2im(first_net_G_output['fake_images']),
+                    tensor2im(first_net_G_output['fake_raw_images']),
+                    [np.zeros_like(item) for item in guidance_image],
+                    [np.zeros_like(item) for item in guidance_mask]
+                ]
+                if self.cfg.trainer.model_average_config.enabled:
+                    vis_images_first += [
+                        tensor2im(first_net_G_output_avg['fake_images']),
+                        tensor2im(first_net_G_output_avg['fake_raw_images'])]
+
+                if self.use_flow:
+                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
+                        data['images'][:, -1], data['images'][:, -2])
+                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
+                    vis_images_first += [
+                        tensor2flow(flow_gt),
+                        tensor2im(conf_gt, normalize=False),
+                        tensor2im(warped_image_gt),
+                    ]
+                    vis_images += [
+                        tensor2flow(net_G_output['fake_flow_maps']),
+                        tensor2im(net_G_output['fake_occlusion_masks'],
+                                  normalize=False),
+                        tensor2im(net_G_output['warped_images']),
+                    ]
+                    if self.cfg.trainer.model_average_config.enabled:
+                        vis_images_first += [
+                            tensor2flow(flow_gt),
+                            tensor2im(conf_gt, normalize=False),
+                            tensor2im(warped_image_gt),
+                        ]
+                        vis_images += [
+                            tensor2flow(net_G_output_avg['fake_flow_maps']),
+                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
+                                      normalize=False),
+                            tensor2im(net_G_output_avg['warped_images'])]
+
+                vis_images = [[np.vstack((im_first, im))
+                               for im_first, im in zip(imgs_first, imgs)]
+                              for imgs_first, imgs in zip(vis_images_first,
+                                                          vis_images)
+                              if imgs is not None]
+
+            image_grid = np.hstack([np.vstack(im) for im in
+                                    vis_images if im is not None])
+
+            print('Save output images to {}'.format(path))
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            imageio.imwrite(path, image_grid)
+
+            # Gather all inputs and outputs for dumping into video.
+            if self.sequence_length > 1:
+                input_images, output_images, output_guidance = [], [], []
+                for item in all_info['inputs']:
+                    input_images.append(tensor2im(item['image'])[0])
+                for item in all_info['outputs']:
+                    output_images.append(tensor2im(item['fake_images'])[0])
+                    if item['guidance_images_and_masks'] is not None:
+                        output_guidance.append(tensor2im(
+                            item['guidance_images_and_masks'][:, :3])[0])
+                    else:
+                        output_guidance.append(np.zeros_like(output_images[-1]))
+
+                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
+                                 output_images, fps=2, macro_block_size=None)
+                imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4',
+                                 output_guidance, fps=2, macro_block_size=None)
+
+            # for idx, item in enumerate(output_guidance):
+            #     imageio.imwrite(os.path.splitext(
+            #         path)[0] + '_guidance_%d.jpg' % (idx), item)
+            # for idx, item in enumerate(input_images):
+            #     imageio.imwrite(os.path.splitext(
+            #         path)[0] + '_input_%d.jpg' % (idx), item)
+
+        self.net_G.float()
+
+    def _compute_fid(self):
+        r"""Compute fid. Ignore for faster training."""
+        return None
+
+    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
+        r"""Save network weights, optimizer parameters, scheduler parameters
+        in the checkpoint.
+
+        Args:
+            cfg (obj): Global configuration.
+            checkpoint_path (str): Path to the checkpoint.
+        """
+        # Create the single image model.
+        if self.train_data_loader is None:
+            load_single_image_model_weights = False
+        else:
+            load_single_image_model_weights = True
+        self.net_G.module._init_single_image_model(
+            load_weights=load_single_image_model_weights)
+
+        # Call the original super function.
+        return super().load_checkpoint(cfg, checkpoint_path, resume, load_sch)
diff --git a/imaginaire/utils/__init__.py b/imaginaire/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/utils/cudnn.py b/imaginaire/utils/cudnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7a5cc3b5607c56e997a6c38c184e4b3f4e302f8
--- /dev/null
+++ b/imaginaire/utils/cudnn.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.backends.cudnn as cudnn
+
+from imaginaire.utils.distributed import master_only_print as print
+
+
+def init_cudnn(deterministic, benchmark):
+    r"""Initialize the cudnn module. The two things to consider is whether to
+    use cudnn benchmark and whether to use cudnn deterministic. If cudnn
+    benchmark is set, then the cudnn deterministic is automatically false.
+
+    Args:
+        deterministic (bool): Whether to use cudnn deterministic.
+        benchmark (bool): Whether to use cudnn benchmark.
+    """
+    cudnn.deterministic = deterministic
+    cudnn.benchmark = benchmark
+    print('cudnn benchmark: {}'.format(benchmark))
+    print('cudnn deterministic: {}'.format(deterministic))
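+
+
+if __name__ == '__main__':
+    # Illustrative usage sketch, not part of the original patch: a typical
+    # training run favors speed over reproducibility, so benchmarking is
+    # enabled and determinism is disabled.
+    init_cudnn(deterministic=False, benchmark=True)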
diff --git a/imaginaire/utils/data.py b/imaginaire/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..22268c955ff7df4e5933e8ce2fb0b38c9c0e2f4a
--- /dev/null
+++ b/imaginaire/utils/data.py
@@ -0,0 +1,612 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa: E712
+"""Utils for handling datasets."""
+
+import time
+import numpy as np
+from PIL import Image
+
+# https://github.com/albumentations-team/albumentations#comments
+import cv2
+# from imaginaire.utils.distributed import master_only_print as print
+import albumentations as alb  # noqa nopep8
+
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+IMG_EXTENSIONS = ('jpg', 'jpeg', 'png', 'ppm', 'bmp',
+                  'pgm', 'tif', 'tiff', 'webp',
+                  'JPG', 'JPEG', 'PNG', 'PPM', 'BMP',
+                  'PGM', 'TIF', 'TIFF', 'WEBP')
+HDR_IMG_EXTENSIONS = ('hdr',)
+VIDEO_EXTENSIONS = 'mp4'
+
+
+class Augmentor(object):
+    r"""Handles data augmentation using albumentations library."""
+
+    def __init__(self, aug_list, individual_video_frame_aug_list, image_data_types, is_mask,
+                 keypoint_data_types, interpolator):
+        r"""Initializes augmentation pipeline.
+
+        Args:
+            aug_list (list): List of augmentation operations in sequence.
+            individual_video_frame_aug_list (list): List of augmentation operations in sequence that will be applied
+                to individual frames of videos independently.
+            image_data_types (list): List of keys in expected inputs.
+            is_mask (dict): Maps each image data type to whether it holds
+                discrete masks.
+            keypoint_data_types (list): List of keys which are keypoints.
+            interpolator (str): Name of the cv2 interpolation flag
+                (e.g. ``INTER_LINEAR``) used when resizing.
+        """
+
+        self.aug_list = aug_list
+        self.individual_video_frame_aug_list = individual_video_frame_aug_list
+        self.image_data_types = image_data_types
+        self.is_mask = is_mask
+        self.crop_h, self.crop_w = None, None
+        self.resize_h, self.resize_w = None, None
+        self.resize_smallest_side = None
+        self.max_time_step = 1
+        self.keypoint_data_types = keypoint_data_types
+        self.interpolator = interpolator
+
+        self.augment_ops = self._build_augmentation_ops()
+        self.individual_video_frame_augmentation_ops = self._build_individual_video_frame_augmentation_ops()
+        # Both crop and resize can't be none at the same time.
+        if self.crop_h is None and self.resize_smallest_side is None and \
+                self.resize_h is None:
+            raise ValueError('resize_smallest_side, resize_h_w, '
+                             'and crop_h_w cannot all be missing.')
+        # If resize_smallest_side is given, resize_h_w should not be given.
+        if self.resize_smallest_side is not None:
+            assert self.resize_h is None, \
+                'Cannot have both `resize_smallest_side` and `resize_h_w` set.'
+        if self.resize_smallest_side is None and self.resize_h is None:
+            self.resize_h, self.resize_w = self.crop_h, self.crop_w
+
+    def _build_individual_video_frame_augmentation_ops(self):
+        r"""Builds sequence of augmentation ops that will be applied to each frame in the video independently.
+        Returns:
+            (list of alb.ops): List of augmentation ops.
+        """
+        augs = []
+        for key, value in self.individual_video_frame_aug_list.items():
+            if key == 'random_scale_limit':
+                if type(value) == float:
+                    scale_limit_lb = scale_limit_ub = value
+                    p = 1
+                else:
+                    scale_limit_lb = value['scale_limit_lb']
+                    scale_limit_ub = value['scale_limit_ub']
+                    p = value['p']
+                augs.append(alb.RandomScale(scale_limit=(-scale_limit_lb, scale_limit_ub), p=p))
+            elif key == 'random_crop_h_w':
+                h, w = value.split(',')
+                h, w = int(h), int(w)
+                self.crop_h, self.crop_w = h, w
+                augs.append(alb.PadIfNeeded(min_height=h, min_width=w))
+                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
+        return augs
+
+    def _build_augmentation_ops(self):
+        r"""Builds sequence of augmentation ops.
+        Returns:
+            (list of alb.ops): List of augmentation ops.
+        """
+        augs = []
+        for key, value in self.aug_list.items():
+            if key == 'resize_smallest_side':
+                if isinstance(value, int):
+                    self.resize_smallest_side = value
+                else:
+                    h, w = value.split(',')
+                    h, w = int(h), int(w)
+                    self.resize_smallest_side = (h, w)
+            elif key == 'resize_h_w':
+                h, w = value.split(',')
+                h, w = int(h), int(w)
+                self.resize_h, self.resize_w = h, w
+            elif key == 'random_resize_h_w_aspect':
+                aspect_start, aspect_end = value.find('('), value.find(')')
+                aspect = value[aspect_start+1:aspect_end]
+                aspect_min, aspect_max = aspect.split(',')
+                h, w = value[:aspect_start].split(',')[:2]
+                h, w = int(h), int(w)
+                aspect_min, aspect_max = float(aspect_min), float(aspect_max)
+                augs.append(alb.RandomResizedCrop(
+                    h, w, scale=(1, 1),
+                    ratio=(aspect_min, aspect_max), always_apply=True, p=1))
+                self.resize_h, self.resize_w = h, w
+            elif key == 'rotate':
+                augs.append(alb.Rotate(
+                    limit=value, always_apply=True, p=1))
+            elif key == 'random_rotate_90':
+                augs.append(alb.RandomRotate90(always_apply=False, p=0.5))
+            elif key == 'random_scale_limit':
+                augs.append(alb.RandomScale(scale_limit=(0, value), p=1))
+            elif key == 'random_crop_h_w':
+                h, w = value.split(',')
+                h, w = int(h), int(w)
+                self.crop_h, self.crop_w = h, w
+                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
+            elif key == 'center_crop_h_w':
+                h, w = value.split(',')
+                h, w = int(h), int(w)
+                self.crop_h, self.crop_w = h, w
+                augs.append(alb.CenterCrop(h, w, always_apply=True, p=1))
+            elif key == 'horizontal_flip':
+                # This is handled separately as we need to keep track if this
+                # was applied in order to correctly modify keypoint data.
+                if value:
+                    augs.append(alb.HorizontalFlip(always_apply=False, p=0.5))
+            # The options below including contrast, blur, motion_blur, compression, gamma
+            # were used during developing face-vid2vid.
+            elif key == 'contrast':
+                brightness_limit = value['brightness_limit']
+                contrast_limit = value['contrast_limit']
+                p = value['p']
+                augs.append(alb.RandomBrightnessContrast(
+                    brightness_limit=brightness_limit, contrast_limit=contrast_limit, p=p))
+            elif key == 'blur':
+                blur_limit = value['blur_limit']
+                p = value['p']
+                augs.append(alb.Blur(blur_limit=blur_limit, p=p))
+            elif key == 'motion_blur':
+                blur_limit = value['blur_limit']
+                p = value['p']
+                augs.append(alb.MotionBlur(blur_limit=blur_limit, p=p))
+            elif key == 'compression':
+                quality_lower = value['quality_lower']
+                p = value['p']
+                augs.append(alb.ImageCompression(quality_lower=quality_lower, p=p))
+            elif key == 'gamma':
+                gamma_limit_lb = value['gamma_limit_lb']
+                gamma_limit_ub = value['gamma_limit_ub']
+                p = value['p']
+                augs.append(alb.RandomGamma(gamma_limit=(gamma_limit_lb, gamma_limit_ub), p=p))
+            elif key == 'max_time_step':
+                self.max_time_step = value
+                assert self.max_time_step >= 1, \
+                    'max_time_step has to be at least 1'
+            else:
+                raise ValueError('Unknown augmentation %s' % (key))
+        return augs
+
+    def _choose_image_key(self, inputs):
+        r"""Choose key to replace with 'image' for input to albumentations.
+
+        Returns:
+            key (str): Chosen key to be replaced with 'image'.
+        """
+        if 'image' in inputs:
+            return 'image'
+        for data_type in inputs:
+            if data_type in self.image_data_types:
+                return data_type
+
+    def _choose_keypoint_key(self, inputs):
+        r"""Choose key to replace with 'keypoints' for input to albumentations.
+        Returns:
+            key (str): Chosen key to be replaced with 'keypoints'.
+        """
+        if not self.keypoint_data_types:
+            return None
+        if 'keypoints' in inputs:
+            return 'keypoints'
+        for data_type in inputs:
+            if data_type in self.keypoint_data_types:
+                return data_type
+
+    def _create_augmentation_targets(self, inputs):
+        r"""Create additional targets as required by the albumentation library.
+
+        Args:
+            inputs (dict): Keys are from self.augmentable_data_types. Values can
+                be numpy.ndarray or list of numpy.ndarray
+                (image or list of images).
+        Returns:
+            (dict):
+              - targets (dict): Dict containing mapping of keys to image/mask types.
+              - new_inputs (dict): Dict containing mapping of keys to data.
+        """
+        # Get additional target list.
+        targets, new_inputs = {}, {}
+        for data_type in inputs:
+            if data_type in self.keypoint_data_types:
+                # Keypoint-type.
+                target_type = 'keypoints'
+            elif data_type in self.image_data_types:
+                # Image-type.
+                # Find the target type (image/mask) based on interpolation
+                # method.
+                if self.is_mask[data_type]:
+                    target_type = 'mask'
+                else:
+                    target_type = 'image'
+            else:
+                raise ValueError(
+                    'Data type: %s is not image or keypoint' % (data_type))
+
+            current_data_type_inputs = inputs[data_type]
+            if not isinstance(current_data_type_inputs, list):
+                current_data_type_inputs = [current_data_type_inputs]
+
+            # Create additional_targets and inputs when there are multiples.
+            for idx, new_input in enumerate(current_data_type_inputs):
+                key = data_type
+                if idx > 0:
+                    key = '%s::%05d' % (key, idx)
+                targets[key] = target_type
+                new_inputs[key] = new_input
+
+        return targets, new_inputs
+
+    def _collate_augmented(self, augmented):
+        r"""Collate separated images back into sequence, grouped by keys.
+
+        Args:
+            augmented (dict): Dict containing frames with keys of the form
+            'key', 'key::00001', 'key::00002', ..., 'key::N'.
+        Returns:
+            (dict):
+              - outputs (dict): Dict with lists of collated inputs, i.e. frames
+                of the same key arranged in order ['key', 'key::00001', ..., 'key::N'].
+        """
+        full_keys = sorted(augmented.keys())
+        outputs = {}
+        for full_key in full_keys:
+            if '::' not in full_key:
+                # First occurrence of this key.
+                key = full_key
+                outputs[key] = []
+            else:
+                key = full_key.split('::')[0]
+            outputs[key].append(augmented[full_key])
+        return outputs
+
+    def _get_resize_h_w(self, height, width):
+        r"""Get height and width to resize to, given smallest side.
+
+        Args:
+            height (int): Input image height.
+            width (int): Input image width.
+        Returns:
+            (dict):
+              - height (int): Height to resize image to.
+              - width (int): Width to resize image to.
+        """
+        if self.resize_smallest_side is None:
+            return self.resize_h, self.resize_w
+
+        if isinstance(self.resize_smallest_side, int):
+            resize_smallest_height, resize_smallest_width = self.resize_smallest_side, self.resize_smallest_side
+        else:
+            resize_smallest_height, resize_smallest_width = self.resize_smallest_side
+
+        if height * resize_smallest_width <= width * resize_smallest_height:
+            new_height = resize_smallest_height
+            new_width = int(np.round(new_height * width / float(height)))
+        else:
+            new_width = resize_smallest_width
+            new_height = int(np.round(new_width * height / float(width)))
+        return new_height, new_width
+
+    def _perform_unpaired_augmentation(self, inputs, augment_ops):
+        r"""Perform different data augmentation on different image inputs. Note that this operation only works
+
+        Args:
+            inputs (dict): Keys are from self.image_data_types. Values are list
+                of numpy.ndarray (list of images).
+            augment_ops (list): The augmentation operations.
+        Returns:
+            (dict):
+              - augmented (dict): Augmented inputs, with same keys as inputs.
+              - is_flipped (dict): Flag which tells if images have been LR flipped.
+        """
+        # Process each data type separately as this is unpaired augmentation.
+        is_flipped = {}
+        for data_type in inputs:
+            assert data_type in self.image_data_types
+            augmented, flipped_flag = self._perform_paired_augmentation(
+                {data_type: inputs[data_type]}, augment_ops)
+            inputs[data_type] = augmented[data_type]
+            is_flipped[data_type] = flipped_flag
+        return inputs, is_flipped
+
+    def _perform_paired_augmentation(self, inputs, augment_ops):
+        r"""Perform same data augmentation on all inputs.
+
+        Args:
+            inputs (dict): Keys are from self.augmentable_data_types. Values are
+                list of numpy.ndarray (list of images).
+            augment_ops (list): The augmentation operations.
+
+        Returns:
+            (dict):
+              - augmented (dict): Augmented inputs, with same keys as inputs.
+              - is_flipped (bool): Flag which tells if images have been LR flipped.
+        """
+        # Different data types may have different sizes and we use the largest one as the original size.
+        # Convert PIL images to numpy array.
+        self.original_h, self.original_w = 0, 0
+        for data_type in inputs:
+            if data_type in self.keypoint_data_types or \
+                    data_type not in self.image_data_types:
+                continue
+            for idx in range(len(inputs[data_type])):
+                value = inputs[data_type][idx]
+                # Get resize h, w.
+                w, h = get_image_size(value)
+                self.original_h, self.original_w = max(self.original_h, h), max(self.original_w, w)
+                # self.original_h, self.original_w = h, w
+                # self.resize_h, self.resize_w = self._get_resize_h_w(h, w)
+                # Convert to numpy array with 3 dims (H, W, C).
+                value = np.array(value)
+                if value.ndim == 2:
+                    value = value[..., np.newaxis]
+                inputs[data_type][idx] = value
+        self.resize_h, self.resize_w = self._get_resize_h_w(self.original_h, self.original_w)
+
+        # Add resize op to augmentation ops.
+        aug_ops_with_resize = [alb.Resize(
+            self.resize_h, self.resize_w, interpolation=getattr(cv2, self.interpolator), always_apply=1, p=1
+        )] + augment_ops
+
+        # Create targets.
+        targets, new_inputs = self._create_augmentation_targets(inputs)
+        extra_params = {}
+
+        # Albumentation requires a key called 'image' and
+        # a key called 'keypoints', if any keypoints are being passed in.
+        # Arbitrarily choose one key of image type to be 'image'.
+        chosen_image_key = self._choose_image_key(inputs)
+        new_inputs['image'] = new_inputs.pop(chosen_image_key)
+        targets['image'] = targets.pop(chosen_image_key)
+        # Arbitrarily choose one key of keypoint type to be 'keypoints'.
+        chosen_keypoint_key = self._choose_keypoint_key(inputs)
+        if chosen_keypoint_key is not None:
+            new_inputs['keypoints'] = new_inputs.pop(chosen_keypoint_key)
+            targets['keypoints'] = targets.pop(chosen_keypoint_key)
+            extra_params['keypoint_params'] = alb.KeypointParams(
+                format='xy', remove_invisible=False)
+
+        # Do augmentation.
+        augmented = alb.ReplayCompose(
+            aug_ops_with_resize, additional_targets=targets,
+            **extra_params)(**new_inputs)
+        augmentation_params = augmented.pop('replay')
+
+        # Check if flipping has occurred.
+        is_flipped = False
+        for augmentation_param in augmentation_params['transforms']:
+            if 'HorizontalFlip' in augmentation_param['__class_fullname__']:
+                is_flipped = augmentation_param['applied']
+        self.is_flipped = is_flipped
+
+        # Replace the key 'image' with chosen_image_key, same for 'keypoints'.
+        augmented[chosen_image_key] = augmented.pop('image')
+        if chosen_keypoint_key is not None:
+            augmented[chosen_keypoint_key] = augmented.pop('keypoints')
+
+        # Pack images back into a sequence.
+        augmented = self._collate_augmented(augmented)
+
+        # Convert keypoint types to np.array from list.
+        for data_type in self.keypoint_data_types:
+            augmented[data_type] = np.array(augmented[data_type])
+
+        return augmented, is_flipped
+
+    def perform_augmentation(self, inputs, paired, augment_ops):
+        r"""Entry point for augmentation.
+
+        Args:
+            inputs (dict): Keys are from self.augmentable_data_types. Values are
+                list of numpy.ndarray (list of images).
+            paired (bool): Apply same augmentation to all input keys?
+            augment_ops (list): The augmentation operations.
+        """
+        # Make sure that all inputs are of same size, else trouble will
+        # ensue. This is because different images might have different
+        # aspect ratios.
+        # Check within data type.
+        for data_type in inputs:
+            if data_type in self.keypoint_data_types or \
+                    data_type not in self.image_data_types:
+                continue
+            for idx in range(len(inputs[data_type])):
+                if idx == 0:
+                    w, h = get_image_size(inputs[data_type][idx])
+                else:
+                    this_w, this_h = get_image_size(inputs[data_type][idx])
+                    # assert this_w == w and this_h == h
+                    # assert this_w / (1.0 * this_h) == w / (1.0 * h)
+        # Check across data types.
+        if paired and self.resize_smallest_side is not None:
+            for idx, data_type in enumerate(inputs):
+                if data_type in self.keypoint_data_types or \
+                        data_type not in self.image_data_types:
+                    continue
+        if paired:
+            return self._perform_paired_augmentation(inputs, augment_ops)
+        else:
+            return self._perform_unpaired_augmentation(inputs, augment_ops)
+
+
+def load_from_lmdb(keys, lmdbs):
+    r"""Load keys from lmdb handles.
+
+    Args:
+        keys (dict): This has data_type as key, and a list of paths into LMDB as
+            values.
+        lmdbs (dict): This has data_type as key, and LMDB handle as value.
+    Returns:
+        data (dict): This has data_type as key, and a list of decoded items from
+            LMDBs as value.
+    """
+    data = {}
+    for data_type in keys:
+        if data_type not in data:
+            data[data_type] = []
+        data_type_keys = keys[data_type]
+        if not isinstance(data_type_keys, list):
+            data_type_keys = [data_type_keys]
+        for key in data_type_keys:
+            data[data_type].append(lmdbs[data_type].getitem_by_path(
+                key.encode(), data_type))
+    return data
+
+
+def load_from_folder(keys, handles):
+    r"""Load keys from lmdb handles.
+
+    Args:
+        keys (dict): This has data_type as key, and a list of paths as
+            values.
+        handles (dict): This has data_type as key, and Folder handle as value.
+    Returns:
+        data (dict): This has data_type as key, and a list of decoded items from
+            folders as value.
+    """
+    data = {}
+    for data_type in keys:
+        if data_type not in data:
+            data[data_type] = []
+        data_type_keys = keys[data_type]
+        if not isinstance(data_type_keys, list):
+            data_type_keys = [data_type_keys]
+        for key in data_type_keys:
+            data[data_type].append(handles[data_type].getitem_by_path(
+                key.encode(), data_type))
+    return data
+
+
+def load_from_object_store(keys, handles):
+    r"""Load keys from AWS S3 handles.
+
+    Args:
+        keys (dict): This has data_type as key, and a list of paths as
+            values.
+        handles (dict): This has data_type as key, and object-store handle as
+            value.
+    Returns:
+        data (dict): This has data_type as key, and a list of decoded items
+            from the object store as value.
+    """
+    data = {}
+    for data_type in keys:
+        if data_type not in data:
+            data[data_type] = []
+        data_type_keys = keys[data_type]
+        if not isinstance(data_type_keys, list):
+            data_type_keys = [data_type_keys]
+        for key in data_type_keys:
+            while True:
+                try:
+                    data[data_type].append(handles[data_type].getitem_by_path(key, data_type))
+                except Exception as e:
+                    print(e)
+                    print(key, data_type)
+                    print('Retrying in 30 seconds')
+                    time.sleep(30)
+                    continue
+                break
+    return data
+
+
+def get_paired_input_image_channel_number(data_cfg):
+    r"""Get number of channels for the input image.
+
+    Args:
+        data_cfg (obj): Data configuration structure.
+    Returns:
+        num_channels (int): Number of input image channels.
+    """
+    num_channels = 0
+    for ix, data_type in enumerate(data_cfg.input_types):
+        for k in data_type:
+            if k in data_cfg.input_image:
+                num_channels += data_type[k].num_channels
+                print('Concatenate %s for input.' % data_type)
+    print('\tNum. of channels in the input image: %d' % num_channels)
+    return num_channels
+
+
+def get_paired_input_label_channel_number(data_cfg, video=False):
+    r"""Get number of channels for the input label map.
+
+    Args:
+        data_cfg (obj): Data configuration structure.
+        video (bool): Whether we are dealing with video data.
+    Returns:
+        num_channels (int): Number of input label map channels.
+    """
+    num_labels = 0
+    if not hasattr(data_cfg, 'input_labels'):
+        return num_labels
+    for ix, data_type in enumerate(data_cfg.input_types):
+        for k in data_type:
+            if k in data_cfg.input_labels:
+                if hasattr(data_cfg, 'one_hot_num_classes') and k in data_cfg.one_hot_num_classes:
+                    num_labels += data_cfg.one_hot_num_classes[k]
+                    if getattr(data_cfg, 'use_dont_care', False):
+                        num_labels += 1
+                else:
+                    num_labels += data_type[k].num_channels
+            print('Concatenate %s for input.' % data_type)
+
+    if video:
+        num_time_steps = getattr(data_cfg.train, 'initial_sequence_length',
+                                 None)
+        num_labels *= num_time_steps
+        num_labels += get_paired_input_image_channel_number(data_cfg) * (
+            num_time_steps - 1)
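+        # Worked example (illustrative): with 3 label channels, 3-channel
+        # images and initial_sequence_length = 4, this yields
+        # 3 * 4 + 3 * 3 = 21 input label channels.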
+
+    print('\tNum. of channels in the input label: %d' % num_labels)
+    return num_labels
+
+
+def get_class_number(data_cfg):
+    r"""Get number of classes for class-conditional GAN model
+
+    Args:
+        data_cfg (obj): Data configuration structure.
+
+    Returns:
+        (int): Number of classes.
+    """
+    return data_cfg.num_classes
+
+
+def get_crop_h_w(augmentation):
+    r"""Get height and width of crop.
+
+    Args:
+        augmentation (dict): Dict of applied augmentations.
+
+    Returns:
+        (dict):
+          - crop_h (int): Height of the image crop.
+          - crop_w (int): Width of the image crop.
+    """
+    print(augmentation.__dict__.keys())
+    for k in augmentation.__dict__.keys():
+        if 'crop_h_w' in k:
+            field = augmentation[k]
+            crop_h, crop_w = field.split(',')
+            crop_h = int(crop_h)
+            crop_w = int(crop_w)
+            # assert crop_w == crop_h, 'This implementation only ' \
+            #                          'supports square-shaped images.'
+            print('\tCrop size: (%d, %d)' % (crop_h, crop_w))
+            return crop_h, crop_w
+    raise AttributeError
+
+
+def get_image_size(x):
+    r"""Return (width, height) for a PIL image or an (H, W, C) numpy array."""
+    try:
+        w, h = x.size
+    except Exception:
+        h, w, _ = x.shape
+    return w, h
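+
+
+if __name__ == '__main__':
+    # Illustrative smoke test, not part of the original patch. The toy config
+    # below is a stand-in built from SimpleNamespace; only the fields the
+    # helpers actually read are defined, and all values are assumptions chosen
+    # for the example. Whether the paired-augmentation call runs exactly as
+    # shown depends on the albumentations version pinned by this repository.
+    from types import SimpleNamespace
+
+    toy_cfg = SimpleNamespace(
+        input_types=[{'images': SimpleNamespace(num_channels=3)},
+                     {'seg_maps': SimpleNamespace(num_channels=1)}],
+        input_image=['images'])
+    assert get_paired_input_image_channel_number(toy_cfg) == 3
+
+    augmentor = Augmentor(
+        aug_list={'resize_h_w': '64,64', 'horizontal_flip': True},
+        individual_video_frame_aug_list={},
+        image_data_types=['images'],
+        is_mask={'images': False},
+        keypoint_data_types=[],
+        interpolator='INTER_LINEAR')
+    frames = {'images': [np.random.randint(0, 255, (80, 120, 3),
+                                            dtype=np.uint8)]}
+    augmented, flipped = augmentor.perform_augmentation(
+        frames, paired=True, augment_ops=augmentor.augment_ops)
+    print('Augmented shape:', augmented['images'][0].shape,
+          'flipped:', flipped)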
diff --git a/imaginaire/utils/dataset.py b/imaginaire/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..605bb2e20d2a22a70e38254bea46bd73177c8c5a
--- /dev/null
+++ b/imaginaire/utils/dataset.py
@@ -0,0 +1,120 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import importlib
+
+import torch
+import torch.distributed as dist
+
+from imaginaire.utils.distributed import master_only_print as print
+
+
+def _get_train_and_val_dataset_objects(cfg):
+    r"""Return dataset objects for the training and validation sets.
+
+    Args:
+        cfg (obj): Global configuration file.
+
+    Returns:
+        (dict):
+          - train_dataset (obj): PyTorch training dataset object.
+          - val_dataset (obj): PyTorch validation dataset object.
+    """
+    dataset_module = importlib.import_module(cfg.data.type)
+    train_dataset = dataset_module.Dataset(cfg, is_inference=False)
+    if hasattr(cfg.data.val, 'type'):
+        for key in ['type', 'input_types', 'input_image']:
+            setattr(cfg.data, key, getattr(cfg.data.val, key))
+        dataset_module = importlib.import_module(cfg.data.type)
+    val_dataset = dataset_module.Dataset(cfg, is_inference=True)
+    print('Train dataset length:', len(train_dataset))
+    print('Val dataset length:', len(val_dataset))
+    return train_dataset, val_dataset
+
+
+def _get_data_loader(cfg, dataset, batch_size, not_distributed=False,
+                     shuffle=True, drop_last=True, seed=0):
+    r"""Return data loader .
+
+    Args:
+        cfg (obj): Global configuration file.
+        dataset (obj): PyTorch dataset object.
+        batch_size (int): Batch size.
+        not_distributed (bool): Do not use distributed samplers.
+
+    Return:
+        (obj): Data loader.
+    """
+    not_distributed = not_distributed or not dist.is_initialized()
+    if not_distributed:
+        sampler = None
+    else:
+        sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=seed)
+    num_workers = getattr(cfg.data, 'num_workers', 8)
+    persistent_workers = getattr(cfg.data, 'persistent_workers', False)
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle and (sampler is None),
+        sampler=sampler,
+        pin_memory=True,
+        num_workers=num_workers,
+        drop_last=drop_last,
+        persistent_workers=persistent_workers if num_workers > 0 else False
+    )
+    return data_loader
+
+
+def get_train_and_val_dataloader(cfg, seed=0):
+    r"""Return dataset objects for the training and validation sets.
+
+    Args:
+        cfg (obj): Global configuration file.
+
+    Returns:
+        (dict):
+          - train_data_loader (obj): Train data loader.
+          - val_data_loader (obj): Val data loader.
+    """
+    train_dataset, val_dataset = _get_train_and_val_dataset_objects(cfg)
+    train_data_loader = _get_data_loader(cfg, train_dataset, cfg.data.train.batch_size, drop_last=True, seed=seed)
+    not_distributed = getattr(cfg.data, 'val_data_loader_not_distributed', False)
+    not_distributed = 'video' in cfg.data.type or not_distributed
+    val_data_loader = _get_data_loader(
+        cfg, val_dataset, cfg.data.val.batch_size, not_distributed,
+        shuffle=False, drop_last=getattr(cfg.data.val, 'drop_last', False), seed=seed)
+    return train_data_loader, val_data_loader
+
+
+def _get_test_dataset_object(cfg):
+    r"""Return dataset object for the test set
+
+    Args:
+        cfg (obj): Global configuration file.
+
+    Returns:
+        (obj): PyTorch dataset object.
+    """
+    dataset_module = importlib.import_module(cfg.test_data.type)
+    test_dataset = dataset_module.Dataset(cfg, is_inference=True, is_test=True)
+    return test_dataset
+
+
+def get_test_dataloader(cfg):
+    r"""Return dataset objects for testing
+
+    Args:
+        cfg (obj): Global configuration file.
+
+    Returns:
+        (obj): Val data loader. It may not contain the ground truth.
+    """
+    test_dataset = _get_test_dataset_object(cfg)
+    not_distributed = getattr(
+        cfg.test_data, 'val_data_loader_not_distributed', False)
+    not_distributed = 'video' in cfg.test_data.type or not_distributed
+    test_data_loader = _get_data_loader(
+        cfg, test_dataset, cfg.test_data.test.batch_size, not_distributed,
+        shuffle=False)
+    return test_data_loader
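+
+
+if __name__ == '__main__':
+    # Minimal sketch, not part of the original patch: when torch.distributed
+    # is not initialized, ``_get_data_loader`` falls back to a plain
+    # non-distributed loader. The tiny namespace below stands in for the real
+    # global config and only defines the fields this helper reads.
+    from types import SimpleNamespace
+
+    toy_cfg = SimpleNamespace(data=SimpleNamespace(num_workers=0))
+    toy_dataset = torch.utils.data.TensorDataset(torch.arange(10).float())
+    loader = _get_data_loader(toy_cfg, toy_dataset, batch_size=4,
+                              shuffle=False, drop_last=False)
+    print('Number of batches:', len(loader))  # batches of size 4, 4 and 2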
diff --git a/imaginaire/utils/diff_aug.py b/imaginaire/utils/diff_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..3410004fb68aef074750359cbb420e5dd340bd45
--- /dev/null
+++ b/imaginaire/utils/diff_aug.py
@@ -0,0 +1,142 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+# Modified from https://github.com/mit-han-lab/data-efficient-gans
+import torch
+import torch.nn.functional as F
+
+
+def apply_diff_aug(data, keys, aug_policy, inplace=False, **kwargs):
+    r"""Applies differentiable augmentation.
+    Args:
+        data (dict): Input data.
+        keys (list of str): Keys to the data values that we want to apply
+            differentiable augmentation to.
+        aug_policy (str): Type of augmentation(s), ``'color'``,
+            ``'translation'``, or ``'cutout'`` separated by ``','``.
+    """
+    if aug_policy == '':
+        return data
+    data_aug = data if inplace else {}
+    for key, value in data.items():
+        if key in keys:
+            data_aug[key] = diff_aug(data[key], aug_policy, **kwargs)
+        else:
+            data_aug[key] = data[key]
+    return data_aug
+
+
+def diff_aug(x, policy='', channels_first=True, **kwargs):
+    if policy:
+        if not channels_first:
+            x = x.permute(0, 3, 1, 2)
+        for p in policy.split(','):
+            for f in AUGMENT_FNS[p]:
+                x = f(x, **kwargs)
+        if not channels_first:
+            x = x.permute(0, 2, 3, 1)
+        x = x.contiguous()
+    return x
+
+
+def rand_brightness(x, **kwargs):
+    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
+                        device=x.device) - 0.5)
+    return x
+
+
+def rand_saturation(x, **kwargs):
+    x_mean = x.mean(dim=1, keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
+                                   device=x.device) * 2) + x_mean
+    return x
+
+
+def rand_contrast(x, **kwargs):
+    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
+                                   device=x.device) + 0.5) + x_mean
+    return x
+
+
+def rand_translation(x, ratio=0.125, **kwargs):
+    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(
+        x.size(3) * ratio + 0.5)
+    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1],
+                                  device=x.device)
+    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1],
+                                  device=x.device)
+    # noinspection PyTypeChecker
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(x.size(2), dtype=torch.long, device=x.device),
+        torch.arange(x.size(3), dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+    x = x_pad.permute(0, 2, 3, 1).contiguous()[
+        grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+    return x
+
+
+def rand_cutout(x, ratio=0.5, **kwargs):
+    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2),
+                             size=[x.size(0), 1, 1], device=x.device)
+    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2),
+                             size=[x.size(0), 1, 1], device=x.device)
+    # noinspection PyTypeChecker
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0,
+                         max=x.size(2) - 1)
+    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0,
+                         max=x.size(3) - 1)
+    mask = torch.ones(x.size(0), x.size(2), x.size(3),
+                      dtype=x.dtype, device=x.device)
+    mask[grid_batch, grid_x, grid_y] = 0
+    x = x * mask.unsqueeze(1)
+    return x
+
+
+def rand_translation_scale(x, trans_r=0.125, scale_r=0.125,
+                           mode='bilinear', padding_mode='reflection',
+                           **kwargs):
+    assert x.dim() == 4, "Input must be a 4D tensor."
+    batch_size = x.size(0)
+
+    # Identity transformation.
+    theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat(
+        batch_size, 1, 1)
+
+    # Translation, uniformly sampled from (-trans_r, trans_r).
+    translate = \
+        2 * trans_r * torch.rand(batch_size, 2, device=x.device) - trans_r
+    theta[:, :, 2] += translate
+
+    # Scaling, uniformly sampled from (1-scale_r, 1+scale_r).
+    scale = \
+        2 * scale_r * torch.rand(batch_size, 2, device=x.device) - scale_r
+    theta[:, :, :2] += torch.diag_embed(scale)
+
+    grid = F.affine_grid(theta, x.size())
+    x = F.grid_sample(
+        x.float(), grid.float(), mode=mode, padding_mode=padding_mode)
+    return x
+
+
+AUGMENT_FNS = {
+    'color': [rand_brightness, rand_saturation, rand_contrast],
+    'translation': [rand_translation],
+    'translation_scale': [rand_translation_scale],
+    'cutout': [rand_cutout],
+}
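+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the original patch: apply the
+    # 'color,translation,cutout' policy from the DiffAugment paper to a random
+    # batch and check that the shape is preserved.
+    fake_images = torch.rand(2, 3, 32, 32)
+    augmented = diff_aug(fake_images, policy='color,translation,cutout')
+    assert augmented.shape == fake_images.shape
+    print('DiffAugment output shape:', tuple(augmented.shape))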
diff --git a/imaginaire/utils/distributed.py b/imaginaire/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7ec9d1099684e58a80a61107fe828e292352002
--- /dev/null
+++ b/imaginaire/utils/distributed.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import ctypes
+
+import torch
+import torch.distributed as dist
+
+
+def init_dist(local_rank, backend='nccl', **kwargs):
+    r"""Initialize distributed training"""
+    if dist.is_available():
+        if dist.is_initialized():
+            return torch.cuda.current_device()
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend=backend, init_method='env://', **kwargs)
+
+    # Increase the L2 fetch granularity for faster speed.
+    _libcudart = ctypes.CDLL('libcudart.so')
+    # Set device limit on the current device
+    # cudaLimitMaxL2FetchGranularity = 0x05
+    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
+    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
+    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
+    # assert pValue.contents.value == 128
+
+
+def get_rank():
+    r"""Get rank of the thread."""
+    rank = 0
+    if dist.is_available():
+        if dist.is_initialized():
+            rank = dist.get_rank()
+    return rank
+
+
+def get_world_size():
+    r"""Get world size. How many GPUs are available in this job."""
+    world_size = 1
+    if dist.is_available():
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+    return world_size
+
+
+def master_only(func):
+    r"""Apply this function only to the master GPU."""
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        r"""Simple function wrapper for the master function"""
+        if get_rank() == 0:
+            return func(*args, **kwargs)
+        else:
+            return None
+    return wrapper
+
+
+def is_master():
+    r"""check if current process is the master"""
+    return get_rank() == 0
+
+
+def is_local_master():
+    r"""Check if the current process uses GPU 0 on its node."""
+    return torch.cuda.current_device() == 0
+
+
+@master_only
+def master_only_print(*args):
+    r"""master-only print"""
+    print(*args)
+
+
+def dist_reduce_tensor(tensor, rank=0, reduce='mean'):
+    r""" Reduce to rank 0 """
+    world_size = get_world_size()
+    if world_size < 2:
+        return tensor
+    with torch.no_grad():
+        dist.reduce(tensor, dst=rank)
+        if get_rank() == rank:
+            if reduce == 'mean':
+                tensor /= world_size
+            elif reduce == 'sum':
+                pass
+            else:
+                raise NotImplementedError
+    return tensor
+
+
+def dist_all_reduce_tensor(tensor, reduce='mean'):
+    r""" Reduce to all ranks """
+    world_size = get_world_size()
+    if world_size < 2:
+        return tensor
+    with torch.no_grad():
+        dist.all_reduce(tensor)
+        if reduce == 'mean':
+            tensor /= world_size
+        elif reduce == 'sum':
+            pass
+        else:
+            raise NotImplementedError
+    return tensor
+
+
+def dist_all_gather_tensor(tensor):
+    r""" gather to all ranks """
+    world_size = get_world_size()
+    if world_size < 2:
+        return [tensor]
+    tensor_list = [
+        torch.ones_like(tensor) for _ in range(dist.get_world_size())]
+    with torch.no_grad():
+        dist.all_gather(tensor_list, tensor)
+    return tensor_list
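+
+
+if __name__ == '__main__':
+    # Minimal sketch, not part of the original patch: without a distributed
+    # initialization these helpers degrade gracefully to single-process
+    # values, and the reduction / gather utilities become no-ops.
+    print('rank:', get_rank(), 'world size:', get_world_size())
+    tensor = torch.ones(3)
+    print('all-reduced:', dist_all_reduce_tensor(tensor))
+    print('gathered:', dist_all_gather_tensor(tensor))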
diff --git a/imaginaire/utils/gpu_affinity.py b/imaginaire/utils/gpu_affinity.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f4e9cb40a5a5f9185e903af55694b5952cfe0ff
--- /dev/null
+++ b/imaginaire/utils/gpu_affinity.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import os
+import pynvml
+
+pynvml.nvmlInit()
+
+
+def systemGetDriverVersion():
+    r"""Get Driver Version"""
+    return pynvml.nvmlSystemGetDriverVersion()
+
+
+def deviceGetCount():
+    r"""Get number of devices"""
+    return pynvml.nvmlDeviceGetCount()
+
+
+class device(object):
+    r"""Device used for nvml."""
+    _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
+
+    def __init__(self, device_idx):
+        super().__init__()
+        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
+
+    def getName(self):
+        r"""Get obect name"""
+        return pynvml.nvmlDeviceGetName(self.handle)
+
+    def getCpuAffinity(self):
+        r"""Get CPU affinity"""
+        affinity_string = ''
+        for j in pynvml.nvmlDeviceGetCpuAffinity(
+                self.handle, device._nvml_affinity_elements):
+            # assume nvml returns list of 64 bit ints
+            affinity_string = '{:064b}'.format(j) + affinity_string
+        affinity_list = [int(x) for x in affinity_string]
+        affinity_list.reverse()  # so core 0 is in 0th element of list
+
+        return [i for i, e in enumerate(affinity_list) if e != 0]
+
+
+def set_affinity(gpu_id=None):
+    r"""Set GPU affinity
+
+    Args:
+        gpu_id (int): Which gpu device.
+    """
+    if gpu_id is None:
+        gpu_id = int(os.getenv('LOCAL_RANK', 0))
+
+    dev = device(gpu_id)
+    os.sched_setaffinity(0, dev.getCpuAffinity())
+
+    # list of ints
+    # representing the logical cores this process is now affinitied with
+    return os.sched_getaffinity(0)
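+
+
+# Typical usage (illustrative, not part of the original patch): call
+# ``set_affinity(local_rank)`` once per training process, before data loaders
+# are created, so that its worker threads are pinned to the CPU cores that
+# NVML associates with that process's GPU.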
diff --git a/imaginaire/utils/init_weight.py b/imaginaire/utils/init_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d826c27d7fe1ab75bfe565b40531acd02abd2b
--- /dev/null
+++ b/imaginaire/utils/init_weight.py
@@ -0,0 +1,84 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.nn import init
+
+
+def weights_init(init_type='normal', gain=0.02, bias=None):
+    r"""Initialize weights in the network.
+
+    Args:
+        init_type (str): The name of the initialization scheme.
+        gain (float): The parameter that is required for the initialization
+            scheme.
+        bias (object): If not ``None``, specifies the initialization parameter
+            for bias.
+
+    Returns:
+        (obj): init function to be applied.
+    """
+
+    def init_func(m):
+        r"""Init function
+
+        Args:
+            m: module to be weight initialized.
+        """
+        class_name = m.__class__.__name__
+        if hasattr(m, 'weight') and (
+                class_name.find('Conv') != -1 or
+                class_name.find('Linear') != -1 or
+                class_name.find('Embedding') != -1):
+            lr_mul = getattr(m, 'lr_mul', 1.)
+            gain_final = gain / lr_mul
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, gain_final)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=gain_final)
+            elif init_type == 'xavier_uniform':
+                init.xavier_uniform_(m.weight.data, gain=gain_final)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                with torch.no_grad():
+                    m.weight.data *= gain_final
+            elif init_type == 'kaiming_linear':
+                init.kaiming_normal_(
+                    m.weight.data, a=0, mode='fan_in', nonlinearity='linear'
+                )
+                with torch.no_grad():
+                    m.weight.data *= gain_final
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=gain_final)
+            elif init_type == 'none':
+                pass
+                # m.reset_parameters()
+            else:
+                raise NotImplementedError(
+                    'initialization method [%s] is '
+                    'not implemented' % init_type)
+        if hasattr(m, 'bias') and m.bias is not None:
+            if init_type == 'none':
+                pass
+            elif bias is not None:
+                bias_type = getattr(bias, 'type', 'normal')
+                if bias_type == 'normal':
+                    bias_gain = getattr(bias, 'gain', 0.5)
+                    init.normal_(m.bias.data, 0.0, bias_gain)
+                else:
+                    raise NotImplementedError(
+                        'initialization method [%s] is '
+                        'not implemented' % bias_type)
+            else:
+                init.constant_(m.bias.data, 0.0)
+    return init_func
+
+
+def weights_rescale():
+    def init_func(m):
+        if hasattr(m, 'init_gain'):
+            for name, p in m.named_parameters():
+                if 'output_scale' not in name:
+                    p.data.mul_(m.init_gain)
+    return init_func
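+
+
+if __name__ == '__main__':
+    # Illustrative sketch, not part of the original patch: ``weights_init``
+    # returns a function that is meant to be passed to ``nn.Module.apply``.
+    import torch.nn as nn
+
+    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
+    net.apply(weights_init(init_type='xavier', gain=0.02))
+    net.apply(weights_rescale())
+    print('Initialized %d parameters' %
+          sum(p.numel() for p in net.parameters()))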
diff --git a/imaginaire/utils/io.py b/imaginaire/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3aa1d737d7a2b0922082c43d4f0c573a482063
--- /dev/null
+++ b/imaginaire/utils/io.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import requests
+import torch.distributed as dist
+import torchvision.utils
+
+from imaginaire.utils.distributed import is_master
+
+
+def save_pilimage_in_jpeg(fullname, output_img):
+    r"""Save PIL Image to JPEG.
+
+    Args:
+        fullname (str): Full save path.
+        output_img (PIL Image): Image to be saved.
+    """
+    dirname = os.path.dirname(fullname)
+    os.makedirs(dirname, exist_ok=True)
+    output_img.save(fullname, 'JPEG', quality=99)
+
+
+def save_intermediate_training_results(
+        visualization_images, logdir, current_epoch, current_iteration):
+    r"""Save intermediate training results for debugging purpose.
+
+    Args:
+        visualization_images (tensor): Image where pixel values are in [-1, 1].
+        logdir (str): Where to save the image.
+        current_epoch (int): Current training epoch.
+        current_iteration (int): Current training iteration.
+    """
+    visualization_images = (visualization_images + 1) / 2
+    output_filename = os.path.join(
+        logdir, 'images',
+        'epoch_{:05}iteration{:09}.jpg'.format(
+            current_epoch, current_iteration))
+    print('Save output images to {}'.format(output_filename))
+    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+    image_grid = torchvision.utils.make_grid(
+        visualization_images.data, nrow=1, padding=0, normalize=False)
+    torchvision.utils.save_image(image_grid, output_filename, nrow=1)
+
+
+def download_file_from_google_drive(URL, destination):
+    r"""Download a file from google drive.
+
+    Args:
+        URL: GDrive file ID.
+        destination: Path to save the file.
+
+    Returns:
+
+    """
+    download_file(f"https://docs.google.com/uc?export=download&id={URL}", destination)
+
+
+def download_file(URL, destination):
+    r"""Download a file from google drive or pbss by using the url.
+
+    Args:
+        URL: GDrive URL or PBSS pre-signed URL for the checkpoint.
+        destination: Path to save the file.
+
+    Returns:
+
+    """
+    session = requests.Session()
+    response = session.get(URL, stream=True)
+    token = get_confirm_token(response)
+    if token:
+        params = {'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+    save_response_content(response, destination)
+
+
+def get_confirm_token(response):
+    r"""Get confirm token
+
+    Args:
+        response: Check if the file exists.
+
+    Returns:
+
+    """
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+    return None
+
+
+def save_response_content(response, destination):
+    r"""Save response content
+
+    Args:
+        response:
+        destination: Path to save the file.
+
+    Returns:
+
+    """
+    chunk_size = 32768
+    with open(destination, "wb") as f:
+        for chunk in response.iter_content(chunk_size):
+            if chunk:
+                f.write(chunk)
+
+
+def get_checkpoint(checkpoint_path, url=''):
+    r"""Get the checkpoint path. If it does not exist yet, download it from
+    the url.
+
+    Args:
+        checkpoint_path (str): Checkpoint path.
+        url (str): URL to download checkpoint.
+    Returns:
+        (str): Full checkpoint path.
+    """
+    if 'TORCH_HOME' not in os.environ:
+        os.environ['TORCH_HOME'] = os.getcwd()
+    save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints')
+    os.makedirs(save_dir, exist_ok=True)
+    full_checkpoint_path = os.path.join(save_dir, checkpoint_path)
+    if not os.path.exists(full_checkpoint_path):
+        os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True)
+        if is_master():
+            print('Downloading {}'.format(url))
+            if 'pbss.s8k.io' not in url:
+                url = f"https://docs.google.com/uc?export=download&id={url}"
+            download_file(url, full_checkpoint_path)
+    if dist.is_available() and dist.is_initialized():
+        dist.barrier()
+    return full_checkpoint_path
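+
+
+if __name__ == '__main__':
+    # Minimal sketch, not part of the original patch: dump a random batch with
+    # values in [-1, 1] into a temporary log directory using the debugging
+    # helper above.
+    import tempfile
+
+    import torch
+
+    fake_batch = torch.rand(4, 3, 64, 64) * 2 - 1
+    save_intermediate_training_results(fake_batch, tempfile.mkdtemp(),
+                                       current_epoch=0, current_iteration=0)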
diff --git a/imaginaire/utils/lmdb.py b/imaginaire/utils/lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..df40c146b73295598cde04fd94a6869c6a5e69d2
--- /dev/null
+++ b/imaginaire/utils/lmdb.py
@@ -0,0 +1,216 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import glob
+import os
+
+import lmdb
+from tqdm import tqdm
+
+from imaginaire.utils import path
+
+
+def construct_file_path(root, data_type, sequence, filename, ext):
+    """Get file path for our dataset structure."""
+    return '%s/%s/%s/%s.%s' % (root, data_type, sequence, filename, ext)
+
+
+def check_and_add(filepath, key, filepaths, keys, remove_missing=False):
+    r"""Add filepath and key to list of filepaths and keys.
+
+    Args:
+        filepath (str): Filepath to add.
+        key (str): LMDB key for this filepath.
+        filepaths (list): List of filepaths added so far.
+        keys (list): List of keys added so far.
+        remove_missing (bool): If ``True``, removes missing files, otherwise
+            raises an error.
+    Returns:
+        (int): Size of file at filepath.
+    """
+    if not os.path.exists(filepath):
+        print(filepath + ' does not exist.')
+        if remove_missing:
+            return -1
+        else:
+            raise FileNotFoundError(filepath + ' does not exist.')
+    filepaths.append(filepath)
+    keys.append(key)
+    return os.path.getsize(filepath)
+
+
+def write_entry(txn, key, filepath):
+    r"""Dump binary contents of file associated with key to LMDB.
+
+    Args:
+        txn: handle to LMDB.
+        key (str): LMDB key for this filepath.
+        filepath (str): Filepath to add.
+    """
+    with open(filepath, 'rb') as f:
+        data = f.read()
+    txn.put(key.encode('ascii'), data)
+
+
+def build_lmdb(filepaths, keys, output_filepath, map_size, large):
+    r"""Write out lmdb containing (key, contents of filepath) to file.
+
+    Args:
+        filepaths (list): List of filepath strings.
+        keys (list): List of key strings associated with filepaths.
+        output_filepath (str): Location to write LMDB to.
+        map_size (int): Size of LMDB.
+        large (bool): Is the dataset large?
+    """
+    if large:
+        db = lmdb.open(output_filepath, map_size=map_size, writemap=True)
+    else:
+        db = lmdb.open(output_filepath, map_size=map_size)
+    txn = db.begin(write=True)
+    print('Writing LMDB to:', output_filepath)
+    for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)):
+        write_entry(txn, key, filepath)
+    txn.commit()
+
+
+def get_all_filenames_from_list(list_name):
+    r"""Get all filenames from list.
+
+    Args:
+        list_name (str): Path to filename list.
+    Returns:
+        all_filenames (dict): Folder name for key, and filename for values.
+    """
+    with open(list_name, 'rt') as f:
+        lines = f.readlines()
+    lines = [line.strip() for line in lines]
+    all_filenames = dict()
+    for line in lines:
+        if '/' in line:
+            file_str = line.split('/')[0:-1]
+            folder_name = os.path.join(*file_str)
+            image_name = line.split('/')[-1].replace('.jpg', '')
+        else:
+            folder_name = '.'
+            image_name = line.replace('.jpg', '')
+        if folder_name in all_filenames:
+            all_filenames[folder_name].append(image_name)
+        else:
+            all_filenames[folder_name] = [image_name]
+    return all_filenames
+
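+# Worked example (illustrative, not part of the original patch): a list file
+# containing the lines
+#     seq0/frame_000.jpg
+#     seq0/frame_001.jpg
+#     frame_root.jpg
+# is parsed into
+#     {'seq0': ['frame_000', 'frame_001'], '.': ['frame_root']}.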
+
+def get_lmdb_data_types(cfg):
+    r"""Get the data types which should be put in LMDB.
+
+    Args:
+        cfg: Configuration object.
+    """
+    data_types, extensions = [], []
+    for data_type in cfg.data.input_types:
+        name = list(data_type.keys())
+        assert len(name) == 1
+        name = name[0]
+        info = data_type[name]
+
+        if 'computed_on_the_fly' not in info:
+            info['computed_on_the_fly'] = False
+        is_lmdb = not info['computed_on_the_fly']
+        if not is_lmdb:
+            continue
+
+        ext = info['ext']
+        data_types.append(name)
+        extensions.append(ext)
+
+    cfg.data.data_types = data_types
+    cfg.data.extensions = extensions
+    return cfg
+
+
+def create_metadata(data_root=None, cfg=None, paired=None, input_list=''):
+    r"""Main function.
+
+    Args:
+        data_root (str): Location of dataset root.
+        cfg (object): Loaded config object.
+        paired (bool): Paired or unpaired data.
+        input_list (str): Path to filename containing list of inputs.
+    Returns:
+        (tuple):
+          - all_filenames (dict): Filenames grouped by sequence (paired) or
+            by data type and sequence (unpaired).
+          - extensions (dict): Extension of each data type.
+    """
+    cfg = get_lmdb_data_types(cfg)
+
+    # Get list of all data_types in the dataset.
+    available_data_types = path.get_immediate_subdirectories(data_root)
+    print(available_data_types)
+    required_data_types = cfg.data.data_types
+    data_exts = cfg.data.extensions
+
+    # Find filenames.
+    missing_data_types = set(required_data_types) - set(available_data_types)
+    assert not missing_data_types, \
+        '%s missing from dataset root.' % missing_data_types
+
+    # Find extensions for each data type.
+    extensions = {}
+    for data_type, data_ext in zip(required_data_types, data_exts):
+        extensions[data_type] = data_ext
+    print('Data file extensions:', extensions)
+
+    if paired:
+        if input_list != '':
+            all_filenames = get_all_filenames_from_list(input_list)
+        else:
+            # Get list of all sequences in the dataset.
+            if 'data_keypoint' in required_data_types:
+                search_dir = 'data_keypoint'
+            elif 'data_segmaps' in required_data_types:
+                search_dir = 'data_segmaps'
+            else:
+                search_dir = required_data_types[0]
+            print('Searching in dir: %s' % search_dir)
+            sequences = path.get_recursive_subdirectories(
+                os.path.join(data_root, search_dir),
+                extensions[search_dir])
+            print('Found %d sequences' % (len(sequences)))
+
+            # Get filenames in each sequence.
+            all_filenames = {}
+            for sequence in sequences:
+                folder = '%s/%s/%s/*.%s' % (
+                    data_root, search_dir, sequence,
+                    extensions[search_dir])
+                filenames = sorted(glob.glob(folder))
+                filenames = [
+                    os.path.splitext(os.path.basename(filename))[0] for
+                    filename in filenames]
+                all_filenames[sequence] = filenames
+            total_filenames = [len(filenames)
+                               for _, filenames in all_filenames.items()]
+            print('Found %d files' % (sum(total_filenames)))
+    else:
+        # Get sequences in each data type.
+        all_filenames = {}
+        for data_type in required_data_types:
+            all_filenames[data_type] = {}
+            sequences = path.get_recursive_subdirectories(
+                os.path.join(data_root, data_type), extensions[data_type])
+
+            # Get filenames in each sequence.
+            total_filenames = 0
+            for sequence in sequences:
+                folder = '%s/%s/%s/*.%s' % (
+                    data_root, data_type, sequence, extensions[data_type])
+                filenames = sorted(glob.glob(folder))
+                filenames = [
+                    os.path.splitext(os.path.basename(filename))[0] for
+                    filename in filenames]
+                all_filenames[data_type][sequence] = filenames
+                total_filenames += len(filenames)
+            print('Data type: %s, Found %d sequences, Found %d files' %
+                  (data_type, len(sequences), total_filenames))
+
+    return all_filenames, extensions
diff --git a/imaginaire/utils/logging.py b/imaginaire/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..449e0b7b892d0e11baccc5c2c2333afec8501422
--- /dev/null
+++ b/imaginaire/utils/logging.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import datetime
+import os
+
+from imaginaire.utils.distributed import master_only
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.meters import set_summary_writer
+
+
+def get_date_uid():
+    """Generate a unique id based on date.
+    Returns:
+        str: Return uid string, e.g. '2017_1122_1713_07'.
+    """
+    return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S"))
+
+
+def init_logging(config_path, logdir):
+    r"""Create log directory for storing checkpoints and output images.
+
+    Args:
+        config_path (str): Path to the configuration file.
+        logdir (str): Log directory name
+    Returns:
+        (str, str): Date uid and log directory path.
+    """
+    config_file = os.path.basename(config_path)
+    root_dir = 'logs'
+    date_uid = get_date_uid()
+    # example: logs/2019_0125_1047_58_spade_cocostuff
+    log_file = '_'.join([date_uid, os.path.splitext(config_file)[0]])
+    if logdir is None:
+        logdir = os.path.join(root_dir, log_file)
+    return date_uid, logdir
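+
+# Illustrative example (hypothetical paths): with config_path =
+# 'configs/spade_cocostuff.yaml' and logdir=None, this returns something like
+# ('2019_0125_1047_58', 'logs/2019_0125_1047_58_spade_cocostuff').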
+
+
+@master_only
+def make_logging_dir(logdir):
+    r"""Create the logging directory
+
+    Args:
+        logdir (str): Log directory name
+    """
+    print('Make folder {}'.format(logdir))
+    os.makedirs(logdir, exist_ok=True)
+    tensorboard_dir = os.path.join(logdir, 'tensorboard')
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    set_summary_writer(tensorboard_dir)
diff --git a/imaginaire/utils/meters.py b/imaginaire/utils/meters.py
new file mode 100644
index 0000000000000000000000000000000000000000..3befb7b1e5fc44c00d3fe29092e75777afa64caa
--- /dev/null
+++ b/imaginaire/utils/meters.py
@@ -0,0 +1,149 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+from datetime import timedelta
+
+import torch
+import wandb
+from wandb import AlertLevel
+from torch.utils.tensorboard import SummaryWriter
+
+from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \
+    is_master, get_rank
+
+from imaginaire.utils.distributed import master_only_print as print
+
+LOG_WRITER = None
+LOG_DIR = None
+
+
+@torch.no_grad()
+def sn_reshape_weight_to_matrix(weight):
+    r"""Reshape weight to obtain the matrix form.
+
+    Args:
+        weight (Parameters): pytorch layer parameter tensor.
+    """
+    weight_mat = weight
+    height = weight_mat.size(0)
+    return weight_mat.reshape(height, -1)
+
+
+@torch.no_grad()
+def get_weight_stats(mod):
+    r"""Get weight state
+
+    Args:
+         mod: Pytorch module
+    """
+    if mod.weight_orig.grad is not None:
+        grad_norm = mod.weight_orig.grad.data.norm().item()
+    else:
+        grad_norm = 0.
+    weight_norm = mod.weight_orig.data.norm().item()
+    weight_mat = sn_reshape_weight_to_matrix(mod.weight_orig)
+    sigma = torch.sum(mod.weight_u * torch.mv(weight_mat, mod.weight_v))
+    return grad_norm, weight_norm, sigma
+
+
+@master_only
+def set_summary_writer(log_dir):
+    r"""Set summary writer
+
+    Args:
+        log_dir (str): Log directory.
+    """
+    global LOG_DIR, LOG_WRITER
+    LOG_DIR = log_dir
+    LOG_WRITER = SummaryWriter(log_dir=log_dir)
+
+
+def write_summary(name, summary, step, hist=False):
+    """Utility function for write summary to log_writer.
+    """
+    global LOG_WRITER
+    lw = LOG_WRITER
+    if lw is None:
+        raise Exception("Log writer not set.")
+    if hist:
+        lw.add_histogram(name, summary, step)
+    else:
+        lw.add_scalar(name, summary, step)
+
+
+class Meter(object):
+    """Meter is to keep track of statistics along steps.
+    Meters write values for purpose like printing average values.
+    Meters can be flushed to log files (i.e. TensorBoard for now)
+    regularly.
+
+    Args:
+        name (str): the name of meter
+        reduce (bool): If ``True``, perform a distributed reduce for the log
+            values across all GPUs.
+    """
+
+    def __init__(self, name, reduce=True):
+        self.name = name
+        self.reduce = reduce
+        self.values = []
+
+    def reset(self):
+        r"""Reset the meter values"""
+        if not self.reduce and get_rank() != 0:
+            return
+        self.values = []
+
+    def write(self, value):
+        r"""Record the value"""
+        if not self.reduce and get_rank() != 0:
+            return
+        if value is not None:
+            self.values.append(value)
+
+    def flush(self, step):
+        r"""Write the value in the tensorboard.
+
+        Args:
+            step (int): Epoch or iteration number.
+        """
+        if not self.reduce and get_rank() != 0:
+            return
+        values = torch.tensor(self.values, device="cuda")
+        if self.reduce:
+            values = dist_all_reduce_tensor(values)
+
+        if not all(math.isfinite(x) for x in values):
+            print("meter {} contained a nan or inf.".format(self.name))
+            if is_master():
+                wandb.alert(
+                    title='NaN',
+                    text=f'Meter {self.name} contained a nan or inf.',
+                    level=AlertLevel.WARN,
+                    wait_duration=timedelta(minutes=120)
+                )
+        filtered_values = list(filter(lambda x: math.isfinite(x), self.values))
+        if float(len(filtered_values)) != 0:
+            value = float(sum(filtered_values)) / float(len(filtered_values))
+            if is_master():
+                write_summary(self.name, value, step)
+                wandb.log({self.name: value}, step=step)
+        self.reset()
+
+    @master_only
+    def write_image(self, img_grid, step):
+        r"""Write the value in the tensorboard.
+
+        Args:
+            img_grid:
+            step (int): Epoch or iteration number.
+        """
+        if not self.reduce and get_rank() != 0:
+            return
+        global LOG_WRITER
+        lw = LOG_WRITER
+        if lw is None:
+            raise Exception("Log writer not set.")
+        lw.add_image("Visualizations", img_grid, step)
diff --git a/imaginaire/utils/misc.py b/imaginaire/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..11ae68652975a2c7e2396c4d7eda2fa2f61fe5a5
--- /dev/null
+++ b/imaginaire/utils/misc.py
@@ -0,0 +1,269 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Miscellaneous utils."""
+import collections
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+string_classes = (str, bytes)
+
+
+def split_labels(labels, label_lengths):
+    r"""Split concatenated labels into their parts.
+
+    Args:
+        labels (torch.Tensor): Labels obtained through concatenation.
+        label_lengths (OrderedDict): Containing order of labels & their lengths.
+
+    Returns:
+        (dict): Maps each data type to its slice of ``labels``.
+    """
+    assert isinstance(label_lengths, OrderedDict)
+    start = 0
+    outputs = {}
+    for data_type, length in label_lengths.items():
+        end = start + length
+        if labels.dim() == 5:
+            outputs[data_type] = labels[:, :, start:end]
+        elif labels.dim() == 4:
+            outputs[data_type] = labels[:, start:end]
+        elif labels.dim() == 3:
+            outputs[data_type] = labels[start:end]
+        start = end
+    return outputs
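+
+# Illustrative example (assumed shapes): with label_lengths =
+# OrderedDict([('seg_maps', 35), ('edge_maps', 1)]) and labels of shape
+# (N, 36, H, W), split_labels returns {'seg_maps': labels[:, 0:35],
+# 'edge_maps': labels[:, 35:36]}.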
+
+
+def requires_grad(model, require=True):
+    r""" Set a model to require gradient or not.
+
+    Args:
+        model (nn.Module): Neural network model.
+        require (bool): Whether the network requires gradient or not.
+    """
+    for p in model.parameters():
+        p.requires_grad = require
+
+
+def to_device(data, device):
+    r"""Move all tensors inside data to device.
+
+    Args:
+        data (dict, list, or tensor): Input data.
+        device (str): 'cpu' or 'cuda'.
+    """
+    assert device in ['cpu', 'cuda']
+    if isinstance(data, torch.Tensor):
+        data = data.to(torch.device(device))
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: to_device(data[key], device) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and \
+            not isinstance(data, string_classes):
+        return [to_device(d, device) for d in data]
+    else:
+        return data
+
+
+def to_cuda(data):
+    r"""Move all tensors inside data to gpu.
+
+    Args:
+        data (dict, list, or tensor): Input data.
+    """
+    return to_device(data, 'cuda')
+
+
+def to_cpu(data):
+    r"""Move all tensors inside data to cpu.
+
+    Args:
+        data (dict, list, or tensor): Input data.
+    """
+    return to_device(data, 'cpu')
+
+
+def to_half(data):
+    r"""Move all floats to half.
+
+    Args:
+        data (dict, list or tensor): Input data.
+    """
+    if isinstance(data, torch.Tensor) and torch.is_floating_point(data):
+        data = data.half()
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: to_half(data[key]) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and \
+            not isinstance(data, string_classes):
+        return [to_half(d) for d in data]
+    else:
+        return data
+
+
+def to_float(data):
+    r"""Move all halfs to float.
+
+    Args:
+        data (dict, list or tensor): Input data.
+    """
+    if isinstance(data, torch.Tensor) and torch.is_floating_point(data):
+        data = data.float()
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: to_float(data[key]) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and \
+            not isinstance(data, string_classes):
+        return [to_float(d) for d in data]
+    else:
+        return data
+
+
+def to_channels_last(data):
+    r"""Move all data to ``channels_last`` format.
+
+    Args:
+        data (dict, list or tensor): Input data.
+    """
+    if isinstance(data, torch.Tensor):
+        if data.dim() == 4:
+            data = data.to(memory_format=torch.channels_last)
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: to_channels_last(data[key]) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and \
+            not isinstance(data, string_classes):
+        return [to_channels_last(d) for d in data]
+    else:
+        return data
+
+
+def slice_tensor(data, start, end):
+    r"""Slice all tensors from start to end.
+    Args:
+        data (dict, list or tensor): Input data.
+    """
+    if isinstance(data, torch.Tensor):
+        data = data[start:end]
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        return {key: slice_tensor(data[key], start, end) for key in data}
+    elif isinstance(data, collections.abc.Sequence) and \
+            not isinstance(data, string_classes):
+        return [slice_tensor(d, start, end) for d in data]
+    else:
+        return data
+
+
+def get_and_setattr(cfg, name, default):
+    r"""Get attribute with default choice. If attribute does not exist, set it
+    using the default value.
+
+    Args:
+        cfg (obj) : Config options.
+        name (str) : Attribute name.
+        default (obj) : Default attribute.
+
+    Returns:
+        (obj) : Desired attribute.
+    """
+    if not hasattr(cfg, name) or name not in cfg.__dict__:
+        setattr(cfg, name, default)
+    return getattr(cfg, name)
+
+
+def get_nested_attr(cfg, attr_name, default):
+    r"""Iteratively try to get the attribute from cfg. If not found, return
+    default.
+
+    Args:
+        cfg (obj): Config file.
+        attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ).
+        default (obj): Default return value for the attribute.
+
+    Returns:
+        (obj): Attribute value.
+    """
+    names = attr_name.split('.')
+    atr = cfg
+    for name in names:
+        if not hasattr(atr, name):
+            return default
+        atr = getattr(atr, name)
+    return atr
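+
+# Illustrative example (hypothetical config fields): get_nested_attr(cfg,
+# 'gen.embed.num_filters', 64) walks cfg.gen.embed.num_filters and returns 64
+# if any attribute along the path is missing.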
+
+
+def gradient_norm(model):
+    r"""Return the gradient norm of model.
+
+    Args:
+        model (PyTorch module): Your network.
+
+    """
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.norm(2)
+            total_norm += param_norm.item() ** 2
+    return total_norm ** (1. / 2)
+
+
+def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'):
+    r"""Randomly shift the input tensor.
+
+    Args:
+        x (4D tensor): The input batch of images.
+        offset (float): The maximum offset ratio, in [0, 1].
+            The maximum shift is ``offset * image_size`` in each direction.
+        mode (str): The resample mode for 'F.grid_sample'.
+        padding_mode (str): The padding mode for 'F.grid_sample'.
+
+    Returns:
+        x (4D tensor) : The randomly shifted image.
+    """
+    assert x.dim() == 4, "Input must be a 4D tensor."
+    batch_size = x.size(0)
+    theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat(
+        batch_size, 1, 1)
+    theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset
+    grid = F.affine_grid(theta, x.size())
+    x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode)
+    return x
+
+
+# def truncated_gaussian(threshold, size, seed=None, device=None):
+#     r"""Apply the truncated gaussian trick to trade diversity for quality
+#
+#     Args:
+#         threshold (float): Truncation threshold.
+#         size (list of integer): Tensor size.
+#         seed (int): Random seed.
+#         device:
+#     """
+#     state = None if seed is None else np.random.RandomState(seed)
+#     values = truncnorm.rvs(-threshold, threshold,
+#                            size=size, random_state=state)
+#     return torch.tensor(values, device=device).float()
+
+
+def apply_imagenet_normalization(input):
+    r"""Normalize using ImageNet mean and std.
+
+    Args:
+        input (4D tensor NxCxHxW): The input images, assumed to be in [-1, 1].
+
+    Returns:
+        Normalized inputs using the ImageNet normalization.
+    """
+    # normalize the input back to [0, 1]
+    normalized_input = (input + 1) / 2
+    # normalize the input using the ImageNet mean and std
+    mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+    std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+    output = (normalized_input - mean) / std
+    return output
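+
+
+# Worked example of the normalization above: an input value of -1.0 in the red
+# channel maps to 0.0 in [0, 1], and then to (0.0 - 0.485) / 0.229 ≈ -2.118
+# after ImageNet standardization.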
diff --git a/imaginaire/utils/model_average.py b/imaginaire/utils/model_average.py
new file mode 100644
index 0000000000000000000000000000000000000000..470428147c9e6cc55df74ad14c010a15cb874a29
--- /dev/null
+++ b/imaginaire/utils/model_average.py
@@ -0,0 +1,215 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+
+import torch
+from torch import nn
+from imaginaire.layers.weight_norm import remove_weight_norms
+from imaginaire.utils.misc import requires_grad
+
+
+def reset_batch_norm(m):
+    r"""Reset batch norm statistics
+
+    Args:
+        m: Pytorch module
+    """
+    if hasattr(m, 'reset_running_stats'):
+        m.reset_running_stats()
+
+
+def calibrate_batch_norm_momentum(m):
+    r"""Calibrate batch norm momentum
+
+    Args:
+        m: Pytorch module
+    """
+    if hasattr(m, 'reset_running_stats'):
+        # if m._get_name() == 'SyncBatchNorm':
+        if 'BatchNorm' in m._get_name():
+            m.momentum = 1.0 / float(m.num_batches_tracked + 1)
+
+
+class ModelAverage(nn.Module):
+    r"""In this model average implementation, the spectral layers are
+    absorbed in the model parameter by default. If such options are
+    turned on, be careful with how you do the training. Remember to
+    re-estimate the batch norm parameters before using the model.
+
+    Args:
+        module (torch nn module): Torch network.
+        beta (float): Moving average weight; how much the past is weighted.
+        start_iteration (int): From which iteration we start the update.
+        remove_wn_wrapper (bool): Whether to remove the weight normalization
+            wrappers (e.g. spectral norm) in the averaged model.
+    """
+    def __init__(
+            self, module, beta=0.9999, start_iteration=1000,
+            remove_wn_wrapper=True
+    ):
+        super(ModelAverage, self).__init__()
+        self.module = module
+        # A shallow copy creates a new object which stores the reference of
+        # the original elements.
+        # A deep copy creates a new object and recursively adds the copies of
+        # nested objects present in the original elements.
+        self.averaged_model = copy.deepcopy(self.module).to('cuda')
+        self.beta = beta
+        self.remove_wn_wrapper = remove_wn_wrapper
+        self.start_iteration = start_iteration
+        # This buffer tracks how many iterations the model has been trained
+        # for. We ignore the first ``start_iteration`` updates and start
+        # averaging after that.
+        self.register_buffer('num_updates_tracked',
+                             torch.tensor(0, dtype=torch.long))
+        self.num_updates_tracked = self.num_updates_tracked.to('cuda')
+        # if self.remove_sn:
+        #     # If we want to remove the spectral norm, we first copy the
+        #     # weights to the moving average model.
+        #     self.copy_s2t()
+        #
+        #     def fn_remove_sn(m):
+        #         r"""Remove spectral norm."""
+        #         if hasattr(m, 'weight_orig'):
+        #             remove_spectral_norm(m)
+        #
+        #     self.averaged_model.apply(fn_remove_sn)
+        #     self.dim = 0
+        if self.remove_wn_wrapper:
+            self.copy_s2t()
+
+            self.averaged_model.apply(remove_weight_norms)
+            self.dim = 0
+        else:
+            self.averaged_model.eval()
+
+        # Averaged model does not require grad.
+        requires_grad(self.averaged_model, False)
+
+    def forward(self, *inputs, **kwargs):
+        r"""PyTorch module forward function overload."""
+        return self.module(*inputs, **kwargs)
+
+    @torch.no_grad()
+    def update_average(self):
+        r"""Update the moving average."""
+        self.num_updates_tracked += 1
+        if self.num_updates_tracked <= self.start_iteration:
+            beta = 0.
+        else:
+            beta = self.beta
+        source_dict = self.module.state_dict()
+        target_dict = self.averaged_model.state_dict()
+        for key in target_dict:
+            if 'num_batches_tracked' in key:
+                continue
+            if self.remove_wn_wrapper:
+                if key.endswith('weight'):
+                    # This is a weight parameter.
+                    if key + '_ori' in source_dict:
+                        # This parameter has scaled lr.
+                        source_param = \
+                            source_dict[key + '_ori'] * \
+                            source_dict[key + '_scale']
+                    elif key + '_orig' in source_dict:
+                        # This parameter has spectral norm
+                        # but not scaled lr.
+                        source_param = source_dict[key + '_orig']
+                    elif key in source_dict:
+                        # This parameter does not have
+                        # weight normalization wrappers.
+                        source_param = source_dict[key]
+                    else:
+                        raise ValueError(
+                            f"{key} required in the averaged model but not "
+                            f"found in the regular model."
+                        )
+                    source_param = source_param.detach()
+
+                    if key + '_orig' in source_dict:
+                        # This parameter has spectral norm.
+                        source_param = self.sn_compute_weight(
+                            source_param,
+                            source_dict[key + '_u'],
+                            source_dict[key + '_v'],
+                        )
+                elif key.endswith('bias') and key + '_ori' in source_dict:
+                    # This is a bias parameter and has scaled lr.
+                    source_param = source_dict[key + '_ori'] * \
+                                   source_dict[key + '_scale']
+                else:
+                    # This is a normal parameter.
+                    source_param = source_dict[key]
+                target_dict[key].data.mul_(beta).add_(
+                    source_param.data, alpha=1 - beta
+                )
+            else:
+                target_dict[key].data.mul_(beta).add_(
+                    source_dict[key].data, alpha=1 - beta
+                )
+
+    @torch.no_grad()
+    def copy_t2s(self):
+        r"""Copy the original weights to the moving average weights."""
+        target_dict = self.module.state_dict()
+        source_dict = self.averaged_model.state_dict()
+        beta = 0.
+        for key in source_dict:
+            target_dict[key].data.copy_(
+                target_dict[key].data * beta +
+                source_dict[key].data * (1 - beta))
+
+    @torch.no_grad()
+    def copy_s2t(self):
+        r""" Copy state_dictionary from source to target.
+        Here source is the regular module and the target is the moving
+        average module. Basically, we will copy weights in the regular module
+        to the moving average module.
+        """
+        source_dict = self.module.state_dict()
+        target_dict = self.averaged_model.state_dict()
+        beta = 0.
+        for key in source_dict:
+            target_dict[key].data.copy_(
+                target_dict[key].data * beta +
+                source_dict[key].data * (1 - beta))
+
+    def __repr__(self):
+        r"""Returns a string that holds a printable representation of an
+        object"""
+        return self.module.__repr__()
+
+    def sn_reshape_weight_to_matrix(self, weight):
+        r"""Reshape weight to obtain the matrix form.
+
+        Args:
+            weight (Parameters): pytorch layer parameter tensor.
+
+        Returns:
+            (Parameters): Reshaped weight matrix
+        """
+        weight_mat = weight
+        if self.dim != 0:
+            # permute dim to front
+            weight_mat = weight_mat.permute(
+                self.dim,
+                *[d for d in range(weight_mat.dim()) if d != self.dim])
+        height = weight_mat.size(0)
+        return weight_mat.reshape(height, -1)
+
+    def sn_compute_weight(self, weight, u, v):
+        r"""Compute the spectral norm normalized matrix.
+
+        Args:
+            weight (Parameters): pytorch layer parameter tensor.
+            u (tensor): left singular vectors.
+            v (tensor): right singular vectors.
+
+        Returns:
+            (Parameters): weight parameter object.
+        """
+        weight_mat = self.sn_reshape_weight_to_matrix(weight)
+        sigma = torch.sum(u * torch.mv(weight_mat, v))
+        weight = weight / sigma
+        return weight
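+
+
+# Illustrative training-loop sketch (not part of the original code), assuming
+# net_G is a generator module and opt_G its optimizer:
+#
+#   net_G = ModelAverage(net_G, beta=0.9999, start_iteration=1000)
+#   for data in loader:
+#       loss = ...              # forward through net_G and compute loss
+#       loss.backward()
+#       opt_G.step()
+#       net_G.update_average()  # EMA update after each optimizer step
+#   # Use net_G.averaged_model (with re-estimated BN stats) for evaluation.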
diff --git a/imaginaire/utils/path.py b/imaginaire/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..e576fc91e66d7c1931b0fb3f349363b49f62c8d5
--- /dev/null
+++ b/imaginaire/utils/path.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Utils to deal with directories and paths."""
+
+import glob
+import os
+
+
+def get_immediate_subdirectories(input_dir):
+    """List dirs immediately under input_dir.
+
+    Args:
+        input_dir (str): Directory to list children of.
+
+    Returns:
+        (list): Sorted list of immediate subdirectory names.
+    """
+    return sorted([name for name in os.listdir(input_dir)
+                   if os.path.isdir(os.path.join(input_dir, name))])
+
+
+def get_recursive_subdirectories(input_dir, ext):
+    """List dirs recursively under input_dir.
+
+    Args:
+        input_dir (str): Directory to list children of.
+        ext (str): Extension of files expected in this directory.
+
+    Returns:
+        (list): Sorted list of directory paths relative to input_dir.
+    """
+    lines = glob.glob('%s/**/*.%s' % (input_dir, ext), recursive=True)
+    dirpaths = [os.path.dirname(item) for item in lines]
+    dirpaths = [os.path.relpath(item, input_dir) for item in dirpaths]
+    dirpaths = sorted(list(set(dirpaths)))
+    return dirpaths
diff --git a/imaginaire/utils/trainer.py b/imaginaire/utils/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb8593ac2904433ce408619d7f28c6fa80ababd
--- /dev/null
+++ b/imaginaire/utils/trainer.py
@@ -0,0 +1,341 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import importlib
+import random
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.optim import SGD, Adam, RMSprop, lr_scheduler
+
+from imaginaire.optimizers import Fromage, Madam
+from imaginaire.utils.distributed import get_rank, get_world_size
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.init_weight import weights_init, weights_rescale
+from imaginaire.utils.model_average import ModelAverage
+
+
+def set_random_seed(seed, by_rank=False):
+    r"""Set random seeds for everything.
+
+    Args:
+        seed (int): Random seed.
+        by_rank (bool): If ``True``, add the process rank to the seed so
+            that each GPU uses a different seed.
+    """
+    if by_rank:
+        seed += get_rank()
+    print(f"Using random seed {seed}")
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_trainer(cfg, net_G, net_D=None,
+                opt_G=None, opt_D=None,
+                sch_G=None, sch_D=None,
+                train_data_loader=None,
+                val_data_loader=None):
+    """Return the trainer object.
+
+    Args:
+        cfg (Config): Loaded config object.
+        net_G (obj): Generator network object.
+        net_D (obj): Discriminator network object.
+        opt_G (obj): Generator optimizer object.
+        opt_D (obj): Discriminator optimizer object.
+        sch_G (obj): Generator optimizer scheduler object.
+        sch_D (obj): Discriminator optimizer scheduler object.
+        train_data_loader (obj): Train data loader.
+        val_data_loader (obj): Validation data loader.
+
+    Returns:
+        (obj): Trainer object.
+    """
+    trainer_lib = importlib.import_module(cfg.trainer.type)
+    trainer = trainer_lib.Trainer(cfg, net_G, net_D,
+                                  opt_G, opt_D,
+                                  sch_G, sch_D,
+                                  train_data_loader, val_data_loader)
+    return trainer
+
+
+def get_model_optimizer_and_scheduler(cfg, seed=0):
+    r"""Return the networks, the optimizers, and the schedulers. We will
+    first set the random seed to a fixed value so that each GPU copy will be
+    initialized to have the same network weights. We will then use different
+    random seeds for different GPUs. After this we will wrap the generator
+    with a moving average model if applicable. It is followed by getting the
+    optimizers and data distributed data parallel wrapping.
+
+    Args:
+        cfg (obj): Global configuration.
+        seed (int): Random seed.
+
+    Returns:
+        (dict):
+          - net_G (obj): Generator network object.
+          - net_D (obj): Discriminator network object.
+          - opt_G (obj): Generator optimizer object.
+          - opt_D (obj): Discriminator optimizer object.
+          - sch_G (obj): Generator optimizer scheduler object.
+          - sch_D (obj): Discriminator optimizer scheduler object.
+    """
+    # We first set the same random seed on every process so that each copy of
+    # the network is initialized in exactly the same way and therefore has
+    # identical weights and other parameters.
+    set_random_seed(seed, by_rank=False)
+    # Construct networks
+    lib_G = importlib.import_module(cfg.gen.type)
+    lib_D = importlib.import_module(cfg.dis.type)
+    net_G = lib_G.Generator(cfg.gen, cfg.data)
+    net_D = lib_D.Discriminator(cfg.dis, cfg.data)
+    print('Initialize net_G and net_D weights using '
+          'type: {} gain: {}'.format(cfg.trainer.init.type,
+                                     cfg.trainer.init.gain))
+    init_bias = getattr(cfg.trainer.init, 'bias', None)
+    net_G.apply(weights_init(
+        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
+    net_D.apply(weights_init(
+        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
+    net_G.apply(weights_rescale())
+    net_D.apply(weights_rescale())
+    # for name, p in net_G.named_parameters():
+    #     if 'modulation' in name and 'bias' in name:
+    #         nn.init.constant_(p.data, 1.)
+    net_G = net_G.to('cuda')
+    net_D = net_D.to('cuda')
+    # Different GPU copies of the same model will receive noises
+    # initialized with different random seeds (if applicable) thanks to the
+    # set_random_seed command (GPU #K has random seed = args.seed + K).
+    set_random_seed(seed, by_rank=True)
+    print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G)))
+    print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D)))
+
+    # Optimizer
+    opt_G = get_optimizer(cfg.gen_opt, net_G)
+    opt_D = get_optimizer(cfg.dis_opt, net_D)
+
+    net_G, net_D, opt_G, opt_D = \
+        wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)
+
+    # Scheduler
+    sch_G = get_scheduler(cfg.gen_opt, opt_G)
+    sch_D = get_scheduler(cfg.dis_opt, opt_D)
+
+    return net_G, net_D, opt_G, opt_D, sch_G, sch_D
+
+
+def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D):
+    r"""Wrap the networks and the optimizers with AMP DDP and (optionally)
+    model average.
+
+    Args:
+        cfg (obj): Global configuration.
+        net_G (obj): Generator network object.
+        net_D (obj): Discriminator network object.
+        opt_G (obj): Generator optimizer object.
+        opt_D (obj): Discriminator optimizer object.
+
+    Returns:
+        (dict):
+          - net_G (obj): Generator network object.
+          - net_D (obj): Discriminator network object.
+          - opt_G (obj): Generator optimizer object.
+          - opt_D (obj): Discriminator optimizer object.
+    """
+    # Apply model average wrapper.
+    if cfg.trainer.model_average_config.enabled:
+        if hasattr(cfg.trainer.model_average_config, 'g_smooth_img'):
+            # Specifies half-life of the running average of generator weights.
+            cfg.trainer.model_average_config.beta = \
+                0.5 ** (cfg.data.train.batch_size *
+                        get_world_size() / cfg.trainer.model_average_config.g_smooth_img)
+            print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}")
+        net_G = ModelAverage(net_G, cfg.trainer.model_average_config.beta,
+                             cfg.trainer.model_average_config.start_iteration,
+                             cfg.trainer.model_average_config.remove_sn)
+    if cfg.trainer.model_average_config.enabled:
+        net_G_module = net_G.module
+    else:
+        net_G_module = net_G
+    if hasattr(net_G_module, 'custom_init'):
+        net_G_module.custom_init()
+
+    net_G = _wrap_model(cfg, net_G)
+    net_D = _wrap_model(cfg, net_D)
+    return net_G, net_D, opt_G, opt_D
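+
+# Worked example of the EMA decay above (hypothetical values): with
+# batch_size=4, world_size=8 and g_smooth_img=10000, beta =
+# 0.5 ** (4 * 8 / 10000) = 0.5 ** 0.0032 ≈ 0.9978, i.e. the running average
+# has a half-life of 10k training images.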
+
+
+def _calculate_model_size(model):
+    r"""Calculate number of parameters in a PyTorch network.
+
+    Args:
+        model (obj): PyTorch network.
+
+    Returns:
+        (int): Number of parameters.
+    """
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class WrappedModel(nn.Module):
+    r"""Dummy wrapping the module.
+    """
+
+    def __init__(self, module):
+        super(WrappedModel, self).__init__()
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        r"""PyTorch module forward function overload."""
+        return self.module(*args, **kwargs)
+
+
+def _wrap_model(cfg, model):
+    r"""Wrap a model for distributed data parallel training.
+
+    Args:
+        model (obj): PyTorch network model.
+
+    Returns:
+        (obj): Wrapped PyTorch network model.
+    """
+    if torch.distributed.is_available() and dist.is_initialized():
+        # ddp = cfg.trainer.distributed_data_parallel
+        find_unused_parameters = cfg.trainer.distributed_data_parallel_params.find_unused_parameters
+        return torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[cfg.local_rank],
+            output_device=cfg.local_rank,
+            find_unused_parameters=find_unused_parameters,
+            broadcast_buffers=False
+        )
+        # if ddp == 'pytorch':
+        #     return torch.nn.parallel.DistributedDataParallel(
+        #         model,
+        #         device_ids=[cfg.local_rank],
+        #         output_device=cfg.local_rank,
+        #         find_unused_parameters=find_unused_parameters,
+        #         broadcast_buffers=False)
+        # else:
+        #     delay_allreduce = cfg.trainer.delay_allreduce
+        #     return apex.parallel.DistributedDataParallel(
+        #         model, delay_allreduce=delay_allreduce)
+    else:
+        return WrappedModel(model)
+
+
+def get_scheduler(cfg_opt, opt):
+    """Return the scheduler object.
+
+    Args:
+        cfg_opt (obj): Config for the specific optimization module (gen/dis).
+        opt (obj): PyTorch optimizer object.
+
+    Returns:
+        (obj): Scheduler
+    """
+    if cfg_opt.lr_policy.type == 'step':
+        scheduler = lr_scheduler.StepLR(
+            opt,
+            step_size=cfg_opt.lr_policy.step_size,
+            gamma=cfg_opt.lr_policy.gamma)
+    elif cfg_opt.lr_policy.type == 'constant':
+        scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1)
+    elif cfg_opt.lr_policy.type == 'linear':
+        # Start linear decay from here.
+        decay_start = cfg_opt.lr_policy.decay_start
+        # End linear decay here.
+        # Continue to train using the lowest learning rate till the end.
+        decay_end = cfg_opt.lr_policy.decay_end
+        # Lowest learning rate multiplier.
+        decay_target = cfg_opt.lr_policy.decay_target
+
+        def sch(x):
+            return min(
+                max(((x - decay_start) * decay_target + decay_end - x) / (
+                    decay_end - decay_start
+                ), decay_target), 1.
+            )
+        scheduler = lr_scheduler.LambdaLR(opt, lambda x: sch(x))
+    else:
+        raise NotImplementedError('Learning rate policy {} not implemented.'.
+                                  format(cfg_opt.lr_policy.type))
+    return scheduler
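+
+# Worked example of the 'linear' policy above (hypothetical values): with
+# decay_start=100000, decay_end=200000 and decay_target=0.1, the multiplier is
+# 1.0 before step 100k, 0.1 after step 200k, and e.g.
+# ((150000 - 100000) * 0.1 + 200000 - 150000) / 100000 = 0.55 at step 150k.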
+
+
+def get_optimizer(cfg_opt, net):
+    r"""Return the scheduler object.
+
+    Args:
+        cfg_opt (obj): Config for the specific optimization module (gen/dis).
+        net (obj): PyTorch network object.
+
+    Returns:
+        (obj): Pytorch optimizer
+    """
+    if hasattr(net, 'get_param_groups'):
+        # Allow the network to use different hyper-parameters (e.g., learning
+        # rate) for different parameters.
+        params = net.get_param_groups(cfg_opt)
+    else:
+        params = net.parameters()
+    return get_optimizer_for_params(cfg_opt, params)
+
+
+def get_optimizer_for_params(cfg_opt, params):
+    r"""Return the scheduler object.
+
+    Args:
+        cfg_opt (obj): Config for the specific optimization module (gen/dis).
+        params (obj): Parameters to be optimized.
+
+    Returns:
+        (obj): Optimizer
+    """
+    # Use fused optimizers (from apex) by default when available.
+    fused_opt = cfg_opt.fused_opt
+    try:
+        from apex.optimizers import FusedAdam
+    except ImportError:  # noqa
+        fused_opt = False
+
+    if cfg_opt.type == 'adam':
+        if fused_opt:
+            opt = FusedAdam(params,
+                            lr=cfg_opt.lr, eps=cfg_opt.eps,
+                            betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
+        else:
+            opt = Adam(params,
+                       lr=cfg_opt.lr, eps=cfg_opt.eps,
+                       betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
+
+    elif cfg_opt.type == 'madam':
+        g_bound = getattr(cfg_opt, 'g_bound', None)
+        opt = Madam(params, lr=cfg_opt.lr,
+                    scale=cfg_opt.scale, g_bound=g_bound)
+    elif cfg_opt.type == 'fromage':
+        opt = Fromage(params, lr=cfg_opt.lr)
+    elif cfg_opt.type == 'rmsprop':
+        opt = RMSprop(params, lr=cfg_opt.lr,
+                      eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay)
+    elif cfg_opt.type == 'sgd':
+        if fused_opt:
+            from apex.optimizers import FusedSGD
+            opt = FusedSGD(params,
+                           lr=cfg_opt.lr,
+                           momentum=cfg_opt.momentum,
+                           weight_decay=cfg_opt.weight_decay)
+        else:
+            opt = SGD(params,
+                      lr=cfg_opt.lr,
+                      momentum=cfg_opt.momentum,
+                      weight_decay=cfg_opt.weight_decay)
+    else:
+        raise NotImplementedError(
+            'Optimizer {} is not yet implemented.'.format(cfg_opt.type))
+    return opt
diff --git a/imaginaire/utils/visualization/__init__.py b/imaginaire/utils/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27e3e7c0383a7e5f032593a50930e9d48bd0292b
--- /dev/null
+++ b/imaginaire/utils/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .common import tensor2im, tensor2flow, tensor2label, tensor2pilimage
+from .common import save_tensor_image
+
+__all__ = ['tensor2im', 'tensor2flow', 'tensor2label', 'tensor2pilimage',
+           'save_tensor_image']
diff --git a/imaginaire/utils/visualization/common.py b/imaginaire/utils/visualization/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b68c5b670c386fe9bef916db13c92682b81bd2
--- /dev/null
+++ b/imaginaire/utils/visualization/common.py
@@ -0,0 +1,314 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import cv2
+import numpy as np
+import PIL
+from PIL import Image
+import torch
+import torchvision
+import os
+
+
+def save_tensor_image(
+        filename, image, minus1to1_normalized=False):
+    r"""Convert a 3 dimensional torch tensor to a PIL image with the desired
+    width and height.
+
+    Args:
+        filename (str): Image filename to be saved to.
+        image (3 x W1 x H1 tensor): Image tensor
+        minus1to1_normalized (bool): True if the tensor values are in [-1,
+        1]. Otherwise, we assume the values are in [0, 1].
+    """
+    if len(image.size()) != 3:
+        raise ValueError('Image tensor dimension does not equal 3.')
+    if image.size(0) != 3:
+        raise ValueError('Image does not have 3 channels.')
+    if minus1to1_normalized:
+        # Normalize back to [0, 1]
+        image = (image + 1) * 0.5
+    dirname = os.path.dirname(filename)
+    os.makedirs(dirname, exist_ok=True)
+    image_grid = torchvision.utils.make_grid(
+        image, nrow=1, padding=0, normalize=False)
+    torchvision.utils.save_image(image_grid, filename, nrow=1)
+    return
+
+
+def tensor2pilimage(image, width=None, height=None, minus1to1_normalized=False):
+    r"""Convert a 3 dimensional torch tensor to a PIL image with the desired
+    width and height.
+
+    Args:
+        image (3 x W1 x H1 tensor): Image tensor
+        width (int): Desired width for the result PIL image.
+        height (int): Desired height for the result PIL image.
+        minus1to1_normalized (bool): True if the tensor values are in [-1,
+        1]. Otherwise, we assume the values are in [0, 1].
+
+    Returns:
+        (PIL image): The resulting PIL image.
+    """
+    if len(image.size()) != 3:
+        raise ValueError('Image tensor dimension does not equal 3.')
+    if image.size(0) != 3:
+        raise ValueError('Image does not have 3 channels.')
+    if minus1to1_normalized:
+        # Normalize back to [0, 1]
+        image = (image + 1) * 0.5
+    image = image.detach().cpu().squeeze().numpy()
+    image = np.transpose(image, (1, 2, 0)) * 255
+    output_img = Image.fromarray(np.uint8(image))
+    if width is not None and height is not None:
+        output_img = output_img.resize((width, height), Image.BICUBIC)
+    return output_img
+
+
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True,
+              three_channel_output=True):
+    r"""Convert tensor to image.
+
+    Args:
+        image_tensor (torch.tensor or list of torch.tensor): If tensor then
+            (NxCxHxW) or (NxTxCxHxW) or (CxHxW).
+        imtype (np.dtype): Type of output image.
+        normalize (bool): Is the input image normalized or not?
+        three_channel_output (bool): Should single channel images be made
+            3-channel in the output?
+
+    Returns:
+        (numpy.ndarray or list): Converted image(s); a list if the input is
+        a list or a batched tensor.
+    """
+    if image_tensor is None:
+        return None
+    if isinstance(image_tensor, list):
+        return [tensor2im(x, imtype, normalize) for x in image_tensor]
+    if image_tensor.dim() == 5 or image_tensor.dim() == 4:
+        return [tensor2im(image_tensor[idx], imtype, normalize)
+                for idx in range(image_tensor.size(0))]
+
+    if image_tensor.dim() == 3:
+        image_numpy = image_tensor.cpu().float().numpy()
+        if normalize:
+            image_numpy = (np.transpose(
+                image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+        else:
+            image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+        image_numpy = np.clip(image_numpy, 0, 255)
+        if image_numpy.shape[2] == 1 and three_channel_output:
+            image_numpy = np.repeat(image_numpy, 3, axis=2)
+        elif image_numpy.shape[2] > 3:
+            image_numpy = image_numpy[:, :, :3]
+        return image_numpy.astype(imtype)
+
+
+def tensor2label(segmap, n_label=None, imtype=np.uint8,
+                 colorize=True, output_normalized_tensor=False):
+    r"""Convert segmentation mask tensor to color image.
+    Args:
+        segmap (tensor or list of tensors): Segmentation map(s) of shape
+            (NxCxHxW), (NxTxCxHxW), or (CxHxW).
+        n_label (int): If None, then segmap.size(0).
+        imtype (np.dtype): Type of output image.
+        colorize (bool): Put colors in.
+
+    Returns:
+        (numpy.ndarray or normalized torch image).
+    """
+    if segmap is None:
+        return None
+    if isinstance(segmap, list):
+        return [tensor2label(x, n_label,
+                             imtype, colorize,
+                             output_normalized_tensor) for x in segmap]
+    if segmap.dim() == 5 or segmap.dim() == 4:
+        return [tensor2label(segmap[idx], n_label,
+                             imtype, colorize,
+                             output_normalized_tensor)
+                for idx in range(segmap.size(0))]
+
+    segmap = segmap.float()
+    if not output_normalized_tensor:
+        segmap = segmap.cpu()
+    if n_label is None:
+        n_label = segmap.size(0)
+    if n_label > 1:
+        segmap = segmap.max(0, keepdim=True)[1]
+
+    if output_normalized_tensor:
+        if n_label == 0:
+            segmap = Colorize(256)(segmap).to('cuda')
+        else:
+            segmap = Colorize(n_label)(segmap).to('cuda')
+        return 2 * (segmap.float() / 255) - 1
+    else:
+        if colorize:
+            segmap = Colorize(n_label)(segmap)
+            segmap = np.transpose(segmap.numpy(), (1, 2, 0))
+        else:
+            segmap = segmap.cpu().numpy()
+        return segmap.astype(imtype)
+
+
+def tensor2flow(tensor, imtype=np.uint8):
+    r"""Convert flow tensor to color image.
+
+    Args:
+        tensor (tensor or list of tensors): Flow map(s) of shape
+            (NxCxHxW), (NxTxCxHxW), or (CxHxW).
+        imtype (np.dtype): Type of output image.
+
+    Returns:
+        (numpy.ndarray or normalized torch image).
+    """
+    if tensor is None:
+        return None
+    if isinstance(tensor, list):
+        tensor = [t for t in tensor if t is not None]
+        if not tensor:
+            return None
+        return [tensor2flow(t, imtype) for t in tensor]
+    if tensor.dim() == 5 or tensor.dim() == 4:
+        return [tensor2flow(tensor[b]) for b in range(tensor.size(0))]
+
+    tensor = tensor.detach().cpu().float().numpy()
+    tensor = np.transpose(tensor, (1, 2, 0))
+
+    hsv = np.zeros((tensor.shape[0], tensor.shape[1], 3), dtype=imtype)
+    hsv[:, :, 0] = 255
+    hsv[:, :, 1] = 255
+    mag, ang = cv2.cartToPolar(tensor[..., 0], tensor[..., 1])
+    hsv[..., 0] = ang * 180 / np.pi / 2
+    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
+    return rgb
+
+
+def plot_keypoints(image, keypoints, normalize=True):
+    r"""Plot keypoints on image.
+
+    Args:
+       image (PIL.Image, or numpy.ndarray, or torch.Tensor): Input image.
+       keypoints (np.ndarray or torch.Tensor, Nx2): Keypoint locations.
+       normalize (bool): Whether to normalize the image or not.
+    """
+    if isinstance(image, PIL.Image.Image):
+        image = np.array(image)
+    if isinstance(image, torch.Tensor):
+        image = tensor2im(image, normalize=normalize)
+    if isinstance(image, np.ndarray):
+        assert image.ndim == 3
+        assert image.shape[-1] == 1 or image.shape[-1] == 3
+    if isinstance(keypoints, torch.Tensor):
+        keypoints = keypoints.cpu().numpy()
+    assert keypoints.ndim == 2 and keypoints.shape[1] == 2
+
+    cv2_image = np.ascontiguousarray(image[:, :, ::-1])  # RGB to BGR.
+    for idx in range(keypoints.shape[0]):
+        keypoint = np.round(keypoints[idx]).astype(int)
+        cv2_image = cv2.circle(cv2_image, tuple(keypoint),
+                               5, (0, 255, 0), -1)
+    image = np.ascontiguousarray(cv2_image[:, :, ::-1])
+    return image
+
+
+def labelcolormap(N):
+    r"""Create colors for segmentation label ids.
+
+    Args:
+        N (int): Number of labels.
+    """
+    if N == 35:  # GTA/cityscape train
+        cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0),
+                         (111, 74, 0), (81, 0, 81), (128, 64, 128),
+                         (244, 35, 232), (250, 170, 160), (230, 150, 140),
+                         (70, 70, 70), (102, 102, 156), (190, 153, 153),
+                         (180, 165, 180), (150, 100, 100), (150, 120, 90),
+                         (153, 153, 153), (153, 153, 153), (250, 170, 30),
+                         (220, 220, 0), (107, 142, 35), (152, 251, 152),
+                         (70, 130, 180), (220, 20, 60), (255, 0, 0),
+                         (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90),
+                         (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32),
+                         (0, 0, 142)],
+                        dtype=np.uint8)
+    elif N == 20:  # GTA/cityscape eval
+        cmap = np.array([(128, 64, 128), (244, 35, 232), (70, 70, 70),
+                         (102, 102, 156), (190, 153, 153), (153, 153, 153),
+                         (250, 170, 30), (220, 220, 0), (107, 142, 35),
+                         (152, 251, 152), (220, 20, 60), (255, 0, 0),
+                         (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
+                         (0, 0, 230), (119, 11, 32), (70, 130, 180), (0, 0, 0)],
+                        dtype=np.uint8)
+    else:
+        cmap = np.zeros([N, 3]).astype(np.uint8)
+        for i in range(N):
+            r, g, b = np.zeros(3)
+            for j in range(8):
+                r = r + (1 << (7 - j)) * ((i & (1 << (3 * j))) >> (3 * j))
+                g = g + (1 << (7 - j)) * \
+                    ((i & (1 << (3 * j + 1))) >> (3 * j + 1))
+                b = b + (1 << (7 - j)) * \
+                    ((i & (1 << (3 * j + 2))) >> (3 * j + 2))
+            cmap[i, :] = np.array([r, g, b])
+    return cmap
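+
+# The fallback branch above builds the standard PASCAL-VOC-style colormap by
+# spreading the bits of the label id across the R/G/B channels (bit-reversed).
+# For example, label 1 -> (128, 0, 0), label 2 -> (0, 128, 0) and
+# label 3 -> (128, 128, 0).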
+
+
+class Colorize(object):
+    """Class to colorize segmentation maps."""
+
+    def __init__(self, n=35):
+        self.cmap = labelcolormap(n)
+        self.cmap = torch.from_numpy(self.cmap[:n])
+
+    def __call__(self, seg_map):
+        r"""
+
+        Args:
+            seg_map (tensor): Input Segmentation maps to be colorized.
+        """
+        size = seg_map.size()
+        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+        for label in range(0, len(self.cmap)):
+            mask = (label == seg_map[0]).cpu()
+            color_image[0][mask] = self.cmap[label][0]
+            color_image[1][mask] = self.cmap[label][1]
+            color_image[2][mask] = self.cmap[label][2]
+        return color_image
+
+
+def plot_keypoints_on_black(resize_h, resize_w, crop_h, crop_w, is_flipped,
+                            cfgdata, keypoints):
+    r"""Plot keypoints on black image.
+
+    Args:
+        resize_h (int): Height to be resized to.
+        resize_w (int): Width to be resized to.
+        crop_h (int): Height of the cropping.
+        crop_w (int): Width of the cropping.
+        is_flipped (bool): If image is a flipped version.
+        cfgdata (obj): Data configuration object.
+        keypoints (np.ndarray): Keypoint locations. Shape of
+            (Nx2) or (TxNx2).
+
+    Returns:
+        (list of np.ndarray): List of images (output_h, output_w, 3).
+    """
+    if keypoints.ndim == 2 and keypoints.shape[1] == 2:
+        keypoints = keypoints[np.newaxis, ...]
+
+    outputs = []
+    for t_idx in range(keypoints.shape[0]):
+        cv2_image = np.zeros((crop_h, crop_w, 3)).astype(np.uint8)
+        for idx in range(keypoints[t_idx].shape[0]):
+            keypoint = np.round(keypoints[t_idx][idx]).astype(int)
+            cv2_image = cv2.circle(cv2_image, tuple(keypoint),
+                                   5, (0, 255, 0), -1)
+        image = np.ascontiguousarray(cv2_image[:, :, ::-1])  # BGR to RGB.
+        outputs.append(image)
+
+    return outputs
diff --git a/imaginaire/utils/visualization/face.py b/imaginaire/utils/visualization/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..19728cda34ddf75552fab9c421a6d76af7983542
--- /dev/null
+++ b/imaginaire/utils/visualization/face.py
@@ -0,0 +1,491 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import cv2
+import torch
+from scipy.optimize import curve_fit
+from scipy.signal import medfilt
+import warnings
+from imaginaire.utils.io import get_checkpoint
+
+
+def connect_face_keypoints(resize_h, resize_w, crop_h, crop_w, original_h,
+                           original_w, is_flipped, cfgdata, keypoints):
+    r"""Connect the face keypoints to edges and draw the sketch.
+
+    Args:
+        resize_h (int): Height the input image was resized to.
+        resize_w (int): Width the input image was resized to.
+        crop_h (int): Height the input image was cropped to.
+        crop_w (int): Width the input image was cropped to.
+        original_h (int): Original height of the input image.
+        original_w (int): Original width of the input image.
+        is_flipped (bool): Is the input image flipped.
+        cfgdata (obj): Data configuration.
+        keypoints (NxKx2 numpy array): Facial landmarks (with K keypoints).
+
+    Returns:
+        (list of HxWxC numpy array): Drawn label map.
+    """
+    if hasattr(cfgdata, 'for_face_dataset'):
+        face_cfg = cfgdata.for_face_dataset
+        # Whether to add the upper part of face to label map.
+        add_upper_face = getattr(face_cfg, 'add_upper_face', False)
+        # Whether to add distance transform output to label map.
+        add_dist_map = getattr(face_cfg, 'add_distance_transform', False)
+        # Whether to add positional encoding to label map.
+        add_pos_encode = add_dist_map and getattr(
+            face_cfg, 'add_positional_encode', False)
+    else:
+        add_upper_face = add_dist_map = add_pos_encode = False
+
+    # Mapping from keypoint index to facial part.
+    part_list = [[list(range(0, 17)) + (
+        (list(range(68, 83)) + [0]) if add_upper_face else [])],  # face contour
+                      [range(17, 22)],  # right eyebrow
+                      [range(22, 27)],  # left eyebrow
+                      [[28, 31], range(31, 36), [35, 28]],  # nose
+                      [[36, 37, 38, 39], [39, 40, 41, 36]],  # right eye
+                      [[42, 43, 44, 45], [45, 46, 47, 42]],  # left eye
+                      [range(48, 55), [54, 55, 56, 57, 58, 59, 48],
+                       range(60, 65), [64, 65, 66, 67, 60]],  # mouth and tongue
+    ]
+    if add_upper_face:
+        pts = keypoints[:, :17, :].astype(np.int32)
+        baseline_y = (pts[:, 0:1, 1] + pts[:, -1:, 1]) / 2
+        upper_pts = pts[:, 1:-1, :].copy()
+        upper_pts[:, :, 1] = baseline_y + (
+                baseline_y - upper_pts[:, :, 1]) * 2 // 3
+        keypoints = np.hstack((keypoints, upper_pts[:, ::-1, :]))
+
+    edge_len = 3  # Interpolate 3 keypoints to form a curve when drawing edges.
+    bw = max(1, resize_h // 256)  # Width of the stroke.
+
+    outputs = []
+    for t_idx in range(keypoints.shape[0]):
+        # Edge map for the face region from keypoints.
+        im_edges = np.zeros((resize_h, resize_w, 1), np.uint8)
+        im_dists = np.zeros((resize_h, resize_w, 0), np.uint8)
+        for edge_list in part_list:
+            for e, edge in enumerate(edge_list):
+                # Edge map for the current edge.
+                im_edge = np.zeros((resize_h, resize_w, 1), np.uint8)
+                # Divide a long edge into multiple small edges when drawing.
+                for i in range(0, max(1, len(edge) - 1), edge_len - 1):
+                    sub_edge = edge[i:i + edge_len]
+                    x = keypoints[t_idx, sub_edge, 0]
+                    y = keypoints[t_idx, sub_edge, 1]
+
+                    # Interp keypoints to get the curve shape.
+                    curve_x, curve_y = interp_points(x, y)
+                    draw_edge(im_edges, curve_x, curve_y, bw=bw)
+                    if add_dist_map:
+                        draw_edge(im_edge, curve_x, curve_y, bw=bw)
+
+                if add_dist_map:
+                    # Add distance transform map on each facial part.
+                    im_dist = cv2.distanceTransform(255 - im_edge,
+                                                    cv2.DIST_L1, 3)
+                    im_dist = np.clip((im_dist / 3), 0, 255)
+                    im_dists = np.dstack((im_dists, im_dist))
+
+                if add_pos_encode and e == 0:
+                    # Add positional encoding for the first edge.
+                    from math import pi
+                    im_pos = np.zeros((resize_h, resize_w, 0), np.float32)
+                    for l in range(10):  # noqa: E741
+                        dist = (im_dist.astype(np.float32) - 127.5) / 127.5
+                        sin = np.sin(pi * (2 ** l) * dist)
+                        cos = np.cos(pi * (2 ** l) * dist)
+                        im_pos = np.dstack((im_pos, sin, cos))
+
+        # Combine all components to form the final label map.
+        if add_dist_map:
+            im_edges = np.dstack((im_edges, im_dists))
+        im_edges = im_edges.astype(np.float32) / 255.0
+        if add_pos_encode:
+            im_edges = np.dstack((im_edges, im_pos))
+        outputs.append(im_edges)
+    return outputs
+
+
+def normalize_and_connect_face_keypoints(cfg, is_inference, data):
+    r"""Normalize face keypoints w.r.t. reference face keypoints and connect
+    keypoints to form 2D images.
+
+    Args:
+        cfg (obj): Data configuration.
+        is_inference (bool): Is doing inference or not.
+        data (dict): Input data.
+
+    Returns:
+        (dict): Output data.
+    """
+    assert is_inference
+    resize_h, resize_w = data['images'][0].shape[-2:]
+
+    keypoints = data['label'].numpy()[0]
+    ref_keypoints = data['few_shot_label'].numpy()[0]
+
+    # Get the normalization params and prev data if it's been computed before.
+    dist_scales = prev_keypoints = None
+    if 'common_attr' in data and 'prev_data' in data['common_attr']:
+        dist_scales = data['common_attr']['dist_scales']
+        prev_keypoints = data['common_attr']['prev_data']
+
+    def concat(prev, now, t):
+        r"""Concat prev and now frames in first dimension, up to t frames."""
+        if prev is None:
+            return now
+        return np.vstack([prev, now])[-t:]
+
+    # Normalize face keypoints w.r.t. reference face keypoints.
+    keypoints, dist_scales = \
+        normalize_face_keypoints(keypoints[0], ref_keypoints[0], dist_scales,
+                                 momentum=getattr(cfg.for_face_dataset,
+                                                  'normalize_momentum', 0.9))
+    keypoints = keypoints[np.newaxis, :]
+
+    # Temporally smooth the face keypoints by median filtering.
+    ks = getattr(cfg.for_face_dataset, 'smooth_kernel_size', 5)
+    concat_keypoints = concat(prev_keypoints, keypoints, ks)
+    if ks > 1 and concat_keypoints.shape[0] == ks:
+        keypoints = smooth_face_keypoints(concat_keypoints, ks)
+
+    # Store the computed params.
+    if 'common_attr' not in data:
+        data['common_attr'] = dict()
+    data['common_attr']['dist_scales'] = dist_scales
+    data['common_attr']['prev_data'] = concat_keypoints
+
+    # Draw the keypoints to turn them into images.
+    labels = []
+    for kpt in [keypoints, ref_keypoints]:
+        label = connect_face_keypoints(resize_h, resize_w, None, None, None,
+                                       None, False, cfg, kpt)
+        labels += [torch.from_numpy(label[0]).permute(2, 0, 1).unsqueeze(0)]
+    data['label'], data['few_shot_label'] = labels
+    return data
+
+
+def smooth_face_keypoints(concat_keypoints, ks):
+    r""" Temporally smooth the face keypoints by median filtering.
+
+    Args:
+        concat_keypoints (TxKx2 numpy array): Face keypoints to be filtered.
+        ks (int): Filter kernel size.
+
+    Returns:
+        (1xKx2 numpy array): Output face keypoints.
+    """
+    # Median filtering.
+    filtered_keypoints = medfilt(concat_keypoints, kernel_size=[ks, 1, 1])
+    # Fill in any zero keypoints with the value from previous frame.
+    if (filtered_keypoints == 0).any():
+        for t in range(1, filtered_keypoints.shape[0]):
+            kpt_prev = filtered_keypoints[t - 1]
+            kpt_cur = filtered_keypoints[t]
+            kpt_max = np.maximum(kpt_cur, kpt_prev)
+            kpt_cur[kpt_cur == 0] = kpt_max[kpt_cur == 0]
+            filtered_keypoints[t] = kpt_cur
+    keypoints = filtered_keypoints[ks // 2: ks // 2 + 1]
+    return keypoints
+
+
+def normalize_face_keypoints(keypoints, ref_keypoints, dist_scales=None,
+                             momentum=0.9):
+    r"""Normalize face keypoints w.r.t. the reference face keypoints.
+
+    Args:
+        keypoints (Kx2 numpy array): Target facial keypoints to be normalized.
+        ref_keypoints (Kx2 numpy array): Reference facial keypoints.
+        dist_scales (list of list of floats): Normalization params.
+        momentum (float): Temporal momentum for the normalization params.
+
+    Returns:
+        (Kx2 numpy array): Normalized facial keypoints.
+    """
+    if keypoints.shape[0] == 68:
+        central_keypoints = [8]
+        part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12],
+                     [5, 11], [6, 10], [7, 9, 8],
+                     [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
+                     [27], [28], [29], [30], [31, 35], [32, 34], [33],
+                     [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+                     [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
+                     [57],
+                     [60, 64], [61, 63], [62], [65, 67], [66]
+                     ]
+    else:
+        raise ValueError('Input keypoints type not supported.')
+
+    face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
+    ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)
+
+    def get_mean_dists(pts, face_cen):
+        r"""Get mean distances of the points from face center."""
+        mean_dists_x, mean_dists_y = [], []
+        pts_cen = np.mean(pts, axis=0)
+        for p, pt in enumerate(pts):
+            mean_dists_x.append(np.linalg.norm(pt - pts_cen))
+            mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
+        mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
+        mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
+        return mean_dist_x, mean_dist_y
+
+    dist_scale_x, dist_scale_y = [None] * len(part_list), \
+                                 [None] * len(part_list)
+    if dist_scales is None:
+        dist_scale_x_prev = dist_scale_y_prev = img_scale = None
+    else:
+        dist_scale_x_prev, dist_scale_y_prev, img_scale = dist_scales
+    if img_scale is None:
+        img_scale = (keypoints[:, 0].max() - keypoints[:, 0].min()) \
+                    / (ref_keypoints[:, 0].max() - ref_keypoints[:, 0].min())
+
+    for i, pts_idx in enumerate(part_list):
+        pts = keypoints[pts_idx]
+        pts = pts[pts[:, 0] != 0]
+        if pts.shape[0]:
+            ref_pts = ref_keypoints[pts_idx]
+            mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
+            ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
+            dist_scale_x[i] = ref_dist_x / mean_dist_x * img_scale
+            dist_scale_y[i] = ref_dist_y / mean_dist_y * img_scale
+            if dist_scale_x_prev is not None:
+                dist_scale_x[i] = dist_scale_x_prev[i] * momentum + \
+                    dist_scale_x[i] * (1 - momentum)
+                dist_scale_y[i] = dist_scale_y_prev[i] * momentum + \
+                    dist_scale_y[i] * (1 - momentum)
+
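+            # Rescale this facial part: its spread around its own centroid by dist_scale_x,
+            # and the centroid's offset from the face center by dist_scale_y, so its
+            # proportions match the reference face.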
+            pts_cen = np.mean(pts, axis=0)
+            pts = (pts - pts_cen) * dist_scale_x[i] + \
+                  (pts_cen - face_cen) * dist_scale_y[i] + face_cen
+            keypoints[pts_idx] = pts
+
+    return keypoints, [dist_scale_x, dist_scale_y, img_scale]
+
+
+def npy_to_tensor(keypoints):
+    r"""Convert numpy array to pytorch tensor."""
+    return torch.from_numpy(keypoints).unsqueeze(0)
+
+
+def get_dlib_landmarks_from_image(
+        imgs, predictor_path='shape_predictor_68_face_landmarks.dat'):
+    r"""Get face keypoints from an image.
+
+    Args:
+        imgs (N x 3 x H x W tensor or N x H x W x 3 numpy array): Input images.
+        predictor_path (str): Path to the predictor model.
+    """
+    import dlib
+    predictor_path = get_checkpoint(predictor_path,
+                                    url='1l9zT-AI1yKlfyAb_wl_RjLBSaiWQr8dr')
+    if type(imgs) == torch.Tensor:
+        imgs = ((imgs + 1) / 2 * 255).byte()
+        imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))
+    detector = dlib.get_frontal_face_detector()
+    predictor = dlib.shape_predictor(predictor_path)
+    points = np.zeros([imgs.shape[0], 68, 2], dtype=int)
+    for i in range(imgs.shape[0]):
+        img = imgs[i]
+        dets = detector(img, 1)
+        if len(dets) > 0:
+            # Only returns the first face.
+            shape = predictor(img, dets[0])
+            for b in range(68):
+                points[i, b, 0] = shape.part(b).x
+                points[i, b, 1] = shape.part(b).y
+    return points
+
+
+def get_126_landmarks_from_image(imgs, landmarks_network):
+    r"""Get face keypoints from an image.
+
+    Args:
+        imgs (Nx3xHxW tensor or NxHxWx3 numpy array): Input images.
+        landmarks_network (obj): The landmark detection network.
+
+    Returns:
+        (Nx126x2 numpy array): Predicted landmarks.
+    """
+    if type(imgs) == torch.Tensor:
+        imgs = ((imgs + 1) / 2 * 255).byte()
+        imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))
+
+    landmarks = []
+    for i in range(imgs.shape[0]):
+        img = imgs[i]
+        out_boxes, landmark = \
+            landmarks_network.get_face_boxes_and_landmarks(img)
+        if len(landmark) > 1:
+            # Pick the largest face in the image.
+            face_size_max = face_index = 0
+            for box_idx, out_box in enumerate(out_boxes):
+                face_size = max(out_box[2] - out_box[0],
+                                out_box[3] - out_box[1])
+                if face_size > face_size_max:
+                    face_size_max = face_size
+                    face_index = box_idx
+            landmark = landmark[face_index]
+        elif len(landmark) == 1:
+            landmark = landmark[0]
+        else:
+            landmark = np.zeros((126, 2), dtype=np.float32)
+        landmarks += [landmark[np.newaxis]]
+    landmarks = np.vstack(landmarks).astype(np.float32)
+    return landmarks
+
+
+def convert_face_landmarks_to_image(cfgdata, landmarks, output_size,
+                                    output_tensor=True, cpu_only=False):
+    r"""Convert the facial landmarks to a label map.
+
+    Args:
+        cfgdata (obj): Data configuration.
+        landmarks (NxKx2 numpy array): Facial landmark coordinates to draw.
+        output_size (tuple of int): H, W of output label map.
+        output_tensor (bool): Output tensors instead of numpy arrays.
+        cpu_only (bool): Output CPU tensor only.
+
+    Returns:
+        (NxCxHxW tensor or list of HxWxC numpy arrays): Label maps.
+    """
+    h, w = output_size
+    labels = connect_face_keypoints(h, w, None, None, None, None, False,
+                                    cfgdata, landmarks)
+    if not output_tensor:
+        return labels
+    labels = [torch.from_numpy(label).permute(2, 0, 1).unsqueeze(0)
+              for label in labels]
+    labels = torch.cat(labels)
+    if cpu_only:
+        return labels
+    return labels.cuda()
+
+
+def add_face_keypoints(label_map, image, keypoints):
+    r"""Add additional keypoints to label map.
+
+    Args:
+        label_map (Nx1xHxW tensor or None): Label map to draw on; a zero map is created if None.
+        image (Nx3xHxW tensor): Input image, which provides the output shape and device.
+        keypoints (NxKx2 tensor): Keypoint coordinates, normalized to [-1, 1].
+    """
+    if label_map is None:
+        label_map = torch.zeros_like(image)[:, :1]
+    x, y = keypoints[:, :, 0], keypoints[:, :, 1]
+    h, w = image.shape[-2:]
+    x = ((x + 1) / 2 * w).long()
+    y = ((y + 1) / 2 * h).long()
+    bs = torch.arange(label_map.shape[0]).cuda().view(-1, 1).expand_as(x)
+    label_map[bs, :, y, x] = 1
+    return label_map
+
+
+def draw_edge(im, x, y, bw=1, color=(255, 255, 255), draw_end_points=False):
+    r"""Set colors given a list of x and y coordinates for the edge.
+
+    Args:
+        im (HxWxC numpy array): Canvas to draw.
+        x (1D numpy array): x coordinates of the edge.
+        y (1D numpy array): y coordinates of the edge.
+        bw (int): Width of the stroke.
+        color (list or tuple of int): Color to draw.
+        draw_end_points (bool): Whether to draw end points of the edge.
+    """
+    if x is not None and x.size:
+        h, w = im.shape[0], im.shape[1]
+        # Draw edge.
+        for i in range(-bw, bw):
+            for j in range(-bw, bw):
+                yy = np.maximum(0, np.minimum(h - 1, y + i))
+                xx = np.maximum(0, np.minimum(w - 1, x + j))
+                set_color(im, yy, xx, color)
+
+        # Draw endpoints.
+        if draw_end_points:
+            for i in range(-bw * 2, bw * 2):
+                for j in range(-bw * 2, bw * 2):
+                    if (i ** 2) + (j ** 2) < (4 * bw ** 2):
+                        yy = np.maximum(0, np.minimum(h - 1, np.array(
+                            [y[0], y[-1]]) + i))
+                        xx = np.maximum(0, np.minimum(w - 1, np.array(
+                            [x[0], x[-1]]) + j))
+                        set_color(im, yy, xx, color)
+
+
+def set_color(im, yy, xx, color):
+    r"""Set pixels of the image to the given color.
+
+    Args:
+        im (HxWxC numpy array): Canvas to draw.
+        xx (1D numpy array): x coordinates of the pixels.
+        yy (1D numpy array): y coordinates of the pixels.
+        color (list or tuple of int): Color to draw.
+    """
+    if type(color) != list and type(color) != tuple:
+        color = [color] * 3
+    if len(im.shape) == 3 and im.shape[2] == 3:
+        if (im[yy, xx] == 0).all():
+            im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = \
+                color[0], color[1], color[2]
+        else:
+            for c in range(3):
+                im[yy, xx, c] = ((im[yy, xx, c].astype(float)
+                                  + color[c]) / 2).astype(np.uint8)
+    else:
+        im[yy, xx] = color[0]
+
+
+def interp_points(x, y):
+    r"""Given the start and end points, interpolate to get a curve/line.
+
+    Args:
+        x (1D array): x coordinates of the points to interpolate.
+        y (1D array): y coordinates of the points to interpolate.
+
+    Returns:
+        (dict):
+          - curve_x (1D array): x coordinates of the interpolated points.
+          - curve_y (1D array): y coordinates of the interpolated points.
+    """
+    if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max():
+        curve_y, curve_x = interp_points(y, x)
+        if curve_y is None:
+            return None, None
+    else:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                if len(x) < 3:
+                    popt, _ = curve_fit(linear, x, y)
+                else:
+                    popt, _ = curve_fit(func, x, y)
+                    if abs(popt[0]) > 1:
+                        return None, None
+            except Exception:
+                return None, None
+        if x[0] > x[-1]:
+            x = list(reversed(x))
+            y = list(reversed(y))
+        curve_x = np.linspace(x[0], x[-1], int(np.round(x[-1]-x[0])))
+        if len(x) < 3:
+            curve_y = linear(curve_x, *popt)
+        else:
+            curve_y = func(curve_x, *popt)
+    return curve_x.astype(int), curve_y.astype(int)
+
+
+def func(x, a, b, c):
+    r"""Quadratic fitting function."""
+    return a * x**2 + b * x + c
+
+
+def linear(x, a, b):
+    r"""Linear fitting function."""
+    return a * x + b
diff --git a/imaginaire/utils/visualization/pose.py b/imaginaire/utils/visualization/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca83f971e5f7c7d0206013bb42de278814244ba
--- /dev/null
+++ b/imaginaire/utils/visualization/pose.py
@@ -0,0 +1,409 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import random
+import importlib
+from .common import tensor2im, tensor2label
+from .face import draw_edge, interp_points
+from imaginaire.model_utils.fs_vid2vid import extract_valid_pose_labels
+
+
+def draw_openpose_npy(resize_h, resize_w, crop_h, crop_w, original_h,
+                      original_w, is_flipped, cfgdata, keypoints_npy):
+    r"""Connect the OpenPose keypoints to edges and draw the pose map.
+
+    Args:
+        resize_h (int): Height the input image was resized to.
+        resize_w (int): Width the input image was resized to.
+        crop_h (int): Height the input image was cropped to.
+        crop_w (int): Width the input image was cropped to.
+        original_h (int): Original height of the input image.
+        original_w (int): Original width of the input image.
+        is_flipped (bool): Is the input image flipped.
+        cfgdata (obj): Data configuration.
+        keypoints_npy (list of numpy arrays): OpenPose keypoints, one array per frame.
+
+    Returns:
+        (list of HxWxC numpy array): Drawn label map.
+    """
+    pose_cfg = cfgdata.for_pose_dataset
+    # Whether to draw only the basic keypoints.
+    basic_points_only = getattr(pose_cfg, 'basic_points_only', False)
+    # Whether to remove the face labels to avoid overfitting.
+    remove_face_labels = getattr(pose_cfg, 'remove_face_labels', False)
+    # Whether to randomly drop some keypoints to avoid overfitting.
+    random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0)
+
+    # Get the list of edges to draw.
+    edge_lists = define_edge_lists(basic_points_only)
+    op_key = cfgdata.keypoint_data_types[0]
+    for input_type in cfgdata.input_types:
+        if op_key in input_type:
+            nc = input_type[op_key].num_channels
+    if crop_h is not None:
+        h, w = crop_h, crop_w
+    else:
+        h, w = resize_h, resize_w
+
+    outputs = []
+    for keypoint_npy in keypoints_npy:
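+        # Each person has 137 OpenPose keypoints: 25 body + 70 face + 21 left hand + 21 right hand.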
+        person_keypoints = np.asarray(keypoint_npy).reshape(-1, 137, 3)[0]
+        # Separate out the keypoint array to different parts.
+        pose_pts = person_keypoints[:25]
+        face_pts = person_keypoints[25: (25 + 70)]
+        hand_pts_l = person_keypoints[(25 + 70): (25 + 70 + 21)]
+        hand_pts_r = person_keypoints[-21:]
+        all_pts = [pose_pts, face_pts, hand_pts_l, hand_pts_r]
+        # Remove the keypoints with low confidence.
+        all_pts = [extract_valid_keypoints(pts, edge_lists)
+                   for pts in all_pts]
+
+        # Connect the keypoints to form the label map.
+        pose_img = connect_pose_keypoints(all_pts, edge_lists,
+                                          (h, w, nc),
+                                          basic_points_only,
+                                          remove_face_labels,
+                                          random_drop_prob)
+        pose_img = pose_img.astype(np.float32) / 255.0
+        outputs.append(pose_img)
+    return outputs
+
+
+def openpose_to_npy_largest_only(inputs):
+    r"""Convert OpenPose dicts to numpy arrays of keypoints. Only return the
+    largest/tallest person in each dict.
+
+    Args:
+        inputs (list of dicts): List of OpenPose dicts.
+
+    Returns:
+        (list of numpy arrays): Keypoints.
+    """
+    return base_openpose_to_npy(inputs, return_largest_only=True)
+
+
+def openpose_to_npy(inputs):
+    r"""Conver OpenPose dicts to numpy arrays of keypoints.
+
+    Args:
+        inputs (list of dicts): List of OpenPose dicts.
+
+    Returns:
+        (list of numpy arrays): Keypoints.
+    """
+    return base_openpose_to_npy(inputs, return_largest_only=False)
+
+
+def base_openpose_to_npy(inputs, return_largest_only=False):
+    r"""Convert OpenPose dicts to numpy arrays of keypoints.
+
+    Args:
+        inputs (list of dicts): List of OpenPose dicts.
+        return_largest_only (bool): Whether to return only the largest person.
+
+    Returns:
+        (list of numpy arrays): Keypoints.
+    """
+    outputs_npy = []
+    for input in inputs:
+        people_dict = input['people']
+        n_ppl = max(1, len(people_dict))
+        output_npy = np.zeros((n_ppl, 25 + 70 + 21 + 21, 3), dtype=np.float32)
+        y_len_max = 0
+        for i, person_dict in enumerate(people_dict):
+            # Extract corresponding keypoints from the dict.
+            pose_pts = np.array(person_dict["pose_keypoints_2d"]).reshape(25, 3)
+            face_pts = np.array(person_dict["face_keypoints_2d"]).reshape(70, 3)
+            hand_pts_l = np.array(person_dict["hand_left_keypoints_2d"]
+                                  ).reshape(21, 3)
+            hand_pts_r = np.array(person_dict["hand_right_keypoints_2d"]
+                                  ).reshape(21, 3)
+
+            if return_largest_only:
+                # Get the body length.
+                y = pose_pts[pose_pts[:, 2] > 0.01, 1]
+                y_len = y.max() - y.min()
+                if y_len > y_len_max:
+                    y_len_max = y_len
+                    max_ind = i
+
+            # Concatenate all keypoint together.
+            output_npy[i] = np.vstack([pose_pts, face_pts,
+                                       hand_pts_l, hand_pts_r])
+        if return_largest_only:
+            # Only return the largest person in the dict.
+            output_npy = output_npy[max_ind: max_ind + 1]
+
+        outputs_npy += [output_npy.astype(np.float32)]
+    return outputs_npy
+
+
+def extract_valid_keypoints(pts, edge_lists):
+    r"""Use only the valid keypoints by looking at the detection confidences.
+    If the confidences for all keypoints in an edge are above threshold,
+    keep the keypoints. Otherwise, their coordinates will be set to zero.
+
+    Args:
+        pts (Px3 numpy array): Keypoint xy coordinates + confidence.
+        edge_lists (nested list of ints):  List of keypoint indices for edges.
+
+    Returns:
+        (Px2 numpy array): Output keypoints.
+    """
+    pose_edge_list, _, hand_edge_list, _, face_list = edge_lists
+    p = pts.shape[0]
+    thre = 0.1 if p == 70 else 0.01
+    output = np.zeros((p, 2))
+
+    if p == 70:  # face
+        for edge_list in face_list:
+            for edge in edge_list:
+                if (pts[edge, 2] > thre).all():
+                    output[edge, :] = pts[edge, :2]
+    elif p == 21:  # hand
+        for edge in hand_edge_list:
+            if (pts[edge, 2] > thre).all():
+                output[edge, :] = pts[edge, :2]
+    else:  # pose
+        valid = (pts[:, 2] > thre)
+        output[valid, :] = pts[valid, :2]
+
+    return output
+
+
+def connect_pose_keypoints(pts, edge_lists, size, basic_points_only,
+                           remove_face_labels, random_drop_prob):
+    r"""Draw edges by connecting the keypoints onto the label map.
+
+    Args:
+        pts (Px3 numpy array): Keypoint xy coordinates + confidence.
+        edge_lists (nested list of ints):  List of keypoint indices for edges.
+        size (tuple of int): Output size.
+        basic_points_only (bool): Whether to use only the basic keypoints.
+        remove_face_labels (bool): Whether to remove face labels.
+        random_drop_prob (float): Probability to randomly drop keypoints.
+
+    Returns:
+        (HxWxC numpy array): Output label map.
+    """
+    pose_pts, face_pts, hand_pts_l, hand_pts_r = pts
+    h, w, c = size
+    body_edges = np.zeros((h, w, c), np.uint8)
+    # If using one-hot, different parts of the body will be drawn to
+    # different channels.
+    use_one_hot = c > 3
+    if use_one_hot:
+        assert c == 27
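+        # With the full edge list, each pose edge gets its own channel (0-23);
+        # the hands use channels 24-25 and the face uses channel 26.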
+    pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, \
+        face_list = edge_lists
+
+    # Draw pose edges.
+    h = int(pose_pts[:, 1].max() - pose_pts[:, 1].min())
+    bw = max(1, h // 150)  # Stroke width.
+    body_edges = draw_edges(body_edges, pose_pts, [pose_edge_list], bw,
+                            use_one_hot, random_drop_prob,
+                            colors=pose_color_list, draw_end_points=True)
+
+    if not basic_points_only:
+        # Draw hand edges.
+        bw = max(1, h // 450)
+        for i, hand_pts in enumerate([hand_pts_l, hand_pts_r]):
+            if use_one_hot:
+                k = 24 + i
+                body_edges[:, :, k] = draw_edges(body_edges[:, :, k], hand_pts,
+                                                 [hand_edge_list],
+                                                 bw, False, random_drop_prob,
+                                                 colors=[255] * len(hand_pts))
+            else:
+                body_edges = draw_edges(body_edges, hand_pts, [hand_edge_list],
+                                        bw, False, random_drop_prob,
+                                        colors=hand_color_list)
+        # Draw face edges.
+        if not remove_face_labels:
+            if use_one_hot:
+                k = 26
+                body_edges[:, :, k] = draw_edges(body_edges[:, :, k], face_pts,
+                                                 face_list, bw, False,
+                                                 random_drop_prob)
+            else:
+                body_edges = draw_edges(body_edges, face_pts, face_list, bw,
+                                        False, random_drop_prob)
+    return body_edges
+
+
+def draw_edges(canvas, keypoints, edges_list, bw, use_one_hot,
+               random_drop_prob=0, edge_len=2, colors=None,
+               draw_end_points=False):
+    r"""Draw all the edges in the edge list on the canvas.
+
+    Args:
+        canvas (HxWxK numpy array): Canvas to draw.
+        keypoints (Px2 numpy array): Keypoints.
+        edges_list (nested list of ints): List of edge lists of keypoint indices.
+        bw (int): Stroke width.
+        use_one_hot (bool): Use one-hot encoding or not.
+        random_drop_prob (float): Probability to randomly drop keypoints.
+        edge_len (int): Number of keypoints in an edge.
+        colors (tuple of int): Color to draw.
+        draw_end_points (bool): Whether to draw end points for edges.
+
+    Returns:
+        (HxWxK numpy array): Output.
+    """
+    k = 0
+    for edge_list in edges_list:
+        for i, edge in enumerate(edge_list):
+            for j in range(0, max(1, len(edge) - 1), edge_len - 1):
+                if random.random() > random_drop_prob:
+                    sub_edge = edge[j:j + edge_len]
+                    x, y = keypoints[sub_edge, 0], keypoints[sub_edge, 1]
+                    if 0 not in x:  # Get rid of invalid keypoints.
+                        curve_x, curve_y = interp_points(x, y)
+                        if use_one_hot:
+                            # If using one-hot, draw to different channels of
+                            # the canvas.
+                            draw_edge(canvas[:, :, k], curve_x, curve_y,
+                                      bw=bw, color=255,
+                                      draw_end_points=draw_end_points)
+                        else:
+                            color = colors[i] if colors is not None \
+                                else (255, 255, 255)
+                            draw_edge(canvas, curve_x, curve_y,
+                                      bw=bw, color=color,
+                                      draw_end_points=draw_end_points)
+                k += 1
+    return canvas
+
+
+def define_edge_lists(basic_points_only):
+    r"""Define the list of keypoints that should be connected to form the edges.
+
+    Args:
+        basic_points_only (bool): Whether to use only the basic keypoints.
+    """
+    # Pose edges and corresponding colors.
+    pose_edge_list = [
+        [17, 15], [15, 0], [0, 16], [16, 18],  # head
+        [0, 1], [1, 8],                        # body
+        [1, 2], [2, 3], [3, 4],                # right arm
+        [1, 5], [5, 6], [6, 7],                # left arm
+        [8, 9], [9, 10], [10, 11],             # right leg
+        [8, 12], [12, 13], [13, 14]            # left leg
+    ]
+    pose_color_list = [
+        [153, 0, 153], [153, 0, 102], [102, 0, 153], [51, 0, 153],
+        [153, 0, 51], [153, 0, 0],
+        [153, 51, 0], [153, 102, 0], [153, 153, 0],
+        [102, 153, 0], [51, 153, 0], [0, 153, 0],
+        [0, 153, 51], [0, 153, 102], [0, 153, 153],
+        [0, 102, 153], [0, 51, 153], [0, 0, 153],
+    ]
+
+    if not basic_points_only:
+        pose_edge_list += [
+            [11, 24], [11, 22], [22, 23],  # right foot
+            [14, 21], [14, 19], [19, 20]   # left foot
+        ]
+        pose_color_list += [
+            [0, 153, 153], [0, 153, 153], [0, 153, 153],
+            [0, 0, 153], [0, 0, 153], [0, 0, 153]
+        ]
+
+    # Hand edges and corresponding colors.
+    hand_edge_list = [
+        [0, 1, 2, 3, 4],
+        [0, 5, 6, 7, 8],
+        [0, 9, 10, 11, 12],
+        [0, 13, 14, 15, 16],
+        [0, 17, 18, 19, 20]
+    ]
+    hand_color_list = [
+        [204, 0, 0], [163, 204, 0], [0, 204, 82], [0, 82, 204], [163, 0, 204]
+    ]
+
+    # Face edges.
+    face_list = [
+        [range(0, 17)],   # face contour
+        [range(17, 22)],  # left eyebrow
+        [range(22, 27)],  # right eyebrow
+        [[28, 31], range(31, 36), [35, 28]],   # nose
+        [[36, 37, 38, 39], [39, 40, 41, 36]],  # left eye
+        [[42, 43, 44, 45], [45, 46, 47, 42]],  # right eye
+        [range(48, 55), [54, 55, 56, 57, 58, 59, 48]],  # mouth
+    ]
+
+    return pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, \
+        face_list
+
+
+def tensor2pose(cfg, label_tensor):
+    r"""Convert output tensor to a numpy pose map.
+
+    Args:
+        label_tensor (3D/4D/5D tensor): Label tensor.
+
+    Returns:
+        (HxWx3 numpy array or list of numpy arrays): Pose map.
+    """
+    if label_tensor.dim() == 5 or label_tensor.dim() == 4:
+        return [tensor2pose(cfg, label_tensor[idx])
+                for idx in range(label_tensor.size(0))]
+
+    # If adding additional discriminators, draw the bbox for the regions
+    # (e.g. faces) too.
+    add_dis_cfg = getattr(cfg.dis, 'additional_discriminators', None)
+    if add_dis_cfg is not None:
+        crop_coords = []
+        for name in add_dis_cfg:
+            v = add_dis_cfg[name].vis
+            file, crop_func = v.split('::')
+            file = importlib.import_module(file)
+            crop_func = getattr(file, crop_func)
+            crop_coord = crop_func(cfg.data, label_tensor)
+            if len(crop_coord) > 0:
+                if type(crop_coord[0]) == list:
+                    crop_coords.extend(crop_coord)
+                else:
+                    crop_coords.append(crop_coord)
+
+    pose_cfg = cfg.data.for_pose_dataset
+    pose_type = getattr(pose_cfg, 'pose_type', 'both')
+    remove_face_labels = getattr(pose_cfg, 'remove_face_labels', False)
+    label_tensor = extract_valid_pose_labels(label_tensor, pose_type,
+                                             remove_face_labels)
+
+    # If using both DensePose and OpenPose, overlay one image onto the other
+    # to get the visualization map.
+    dp_key = 'pose_maps-densepose'
+    op_key = 'poses-openpose'
+    use_densepose = use_openpose = False
+    for input_type in cfg.data.input_types:
+        if dp_key in input_type:
+            dp_ch = input_type[dp_key].num_channels
+            use_densepose = True
+        elif op_key in input_type:
+            op_ch = input_type[op_key].num_channels
+            use_openpose = True
+    if use_densepose:
+        label_img = tensor2im(label_tensor[:dp_ch])
+    if use_openpose:
+        openpose = label_tensor[-op_ch:]
+        openpose = tensor2im(openpose) if op_ch == 3 else \
+            tensor2label(openpose, op_ch)
+        if use_densepose:
+            label_img[openpose != 0] = openpose[openpose != 0]
+        else:
+            label_img = openpose
+
+    # Draw the bbox for the regions for the additional discriminator.
+    if add_dis_cfg is not None:
+        for crop_coord in crop_coords:
+            ys, ye, xs, xe = crop_coord
+            label_img[ys, xs:xe, :] = label_img[ye - 1, xs:xe, :] \
+                = label_img[ys:ye, xs, :] = label_img[ys:ye, xe - 1, :] = 255
+
+    if len(label_img.shape) == 2:
+        label_img = np.repeat(label_img[:, :, np.newaxis], 3, axis=2)
+    return label_img
diff --git a/inference/draw_points.py b/inference/draw_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..97364c42d15be50ca87b8e66fbda448371b0e032
--- /dev/null
+++ b/inference/draw_points.py
@@ -0,0 +1,38 @@
+import numpy as np
+from scipy.interpolate import interp1d
+import matplotlib.pyplot as plt
+from matplotlib.image import imread
+import csv,os,sys
+
+data = sys.argv[1]
+assert data.endswith('satView_polish.png')
+img_path = os.path.join('dataset/CVACT/satview_correct',data)
+
+# img_path = './dataset/CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png'
+csv_path = 'vis_video/pixels.csv'
+select_points = [28, 44, 53]
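+# pixels.csv is written by inference/select_points.py (columns: x, y, angle);
+# select_points are indices of individual points to print and mark.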
+
+x_list,y_list = [],[]
+x_whole,y_whole = [],[]
+with open(csv_path, 'r') as csvfile:
+    reader = csv.DictReader(csvfile)
+    for i,row in enumerate(reader):
+        x,y = float(row['x']),float(row['y']) 
+        if  i in select_points:
+            x_list.append(x)
+            y_list.append(y)
+            print(i,x,y)
+        x_whole.append(x)
+        y_whole.append(y)
+fig, ax = plt.subplots()
+
+
+img = imread(img_path)
+plt.imshow(img)
+plt.plot(x_whole, y_whole, 'r-',label='Smooth curve', linewidth=4)
+plt.scatter(x_list,y_list,marker='o', s=0, color='red')
+plt.axis('off')
+plt.xlim([0, 256])
+plt.ylim([256, 0])
+plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+plt.savefig('point_curve.png', bbox_inches='tight', pad_inches=0)
diff --git a/inference/get_score_from_dir.py b/inference/get_score_from_dir.py
new file mode 100644
index 0000000000000000000000000000000000000000..691362706ca5339dce2779917c18580ddb854bbb
--- /dev/null
+++ b/inference/get_score_from_dir.py
@@ -0,0 +1,68 @@
+from torch.utils.data.dataset import Dataset
+
+import os,torch
+from PIL import Image
+import torchvision.transforms as T
+from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
+import torch.nn.functional as F
+from imaginaire.evaluation.segmentation import get_segmentation_hist_model,get_miou,compute_hist
+import lpips
+from easydict import EasyDict as edict
+from tqdm import tqdm
+import piq
+from  torch.utils.data import DataLoader
+from piq import FID,KID
+import numpy as np
+
+result_path = 'result/Ours-pers-sin-sty'
+gt_path = 'dataset/CVACT/streetview_test'
+
+
+class Dataset_img(Dataset):
+    def __init__(self, dir):
+        self.dir = dir
+        self.datalist = sorted(os.listdir(dir))
+    
+    def __len__(self):
+        return len(self.datalist)
+
+    def __getitem__(self, index):
+        img = os.path.join(self.dir,self.datalist[index])
+        img = Image.open(img).convert('RGB')
+        img = T.ToTensor()(img)
+        return {'images':img}
+
+
+
+data_gt = Dataset_img(gt_path)
+data_pred = Dataset_img(result_path)
+
+
+loss_fn_alex = lpips.LPIPS(net='alex',eval_mode=True).cuda()
+loss_fn_squeeze = lpips.LPIPS(net='squeeze',eval_mode=True).cuda()
+
+
+data_list = os.listdir(result_path)
+results = edict()
+results.psnr = []
+results.ssim = []
+results.alex = []
+results.squeeze = []
+results.RMSE  = []
+
+dataloader_pred = DataLoader(data_pred,batch_size=1,shuffle=False,num_workers=10)
+dataloader_gt   = DataLoader(data_gt,batch_size=1,shuffle=False,num_workers=10)
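+# Per-image metrics: PSNR in dB, SSIM, LPIPS (AlexNet and SqueezeNet, inputs rescaled to [-1, 1]),
+# and RMSE in 8-bit pixel units.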
+for i in tqdm(zip(dataloader_pred,dataloader_gt),ncols=100):
+    pred = i[0]['images'].cuda()
+    gt   = i[1]['images'].cuda()
+    results.psnr.append(-10*F.mse_loss(pred,gt).log10().item())
+    results.ssim.append(ssim(pred, gt,data_range=1.).item())
+    results.alex.append(torch.mean(loss_fn_alex((pred*2.)-1, (2.*gt)-1)).cpu().item())
+    results.squeeze.append(torch.mean(loss_fn_squeeze((pred*2.)-1, (2.*gt)-1)).cpu().item())
+    results.RMSE.append(torch.sqrt(F.mse_loss(pred,gt)).item()*255)
+
+for i in results:
+    print("%-10s"%i, ':',np.mean(results[i]))
diff --git a/inference/img2vid.py b/inference/img2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..68d33d2b8276935856875c5687af19b768adb245
--- /dev/null
+++ b/inference/img2vid.py
@@ -0,0 +1,73 @@
+import os  
+import cv2  
+from PIL import Image  
+
+def image_to_video(img_dir,image_names, media_path):
+    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
+    fps = 20  
+    image = Image.open(os.path.join( img_dir , image_names[0]))
+    media_writer = cv2.VideoWriter(media_path, fourcc, fps, image.size)
+    for image_name in image_names:
+        im = cv2.imread(os.path.join(img_dir, image_name))
+        media_writer.write(im)
+        print(image_name, 'combined')
+    media_writer.release()
+    print('end')
+
+def img_pair2vid(sat_list,grd_list,angle_list=None,media_path= 'output.mp4'):
+    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
+    out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 256))
+    out_sat = cv2.VideoWriter(media_path.replace('.mp4','_sat.mp4'), fourcc, 12.0, (389, 389))
+    assert len(sat_list) == len(grd_list)
+    for i  in range(len(sat_list)):
+
+        img1 = cv2.imread(os.path.join( img_dir , sat_list[i]))
+        img2 = cv2.imread(os.path.join( img_dir , grd_list[i]))
+        img3 = cv2.imread(os.path.join( img_dir , grd_list[i].replace('.png','_depth.png')))
+
+
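+        # Roll the 512-px-wide panorama and its depth map horizontally so the view
+        # direction follows the trajectory heading (360 degrees of yaw span 512 px,
+        # hence angle / 180 * 256 px).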
+        if angle_list is not None:
+            angle = angle_list[i]
+            left_pixel = int((angle / 180) * 256)
+            img2 = cv2.hconcat([img2[:, left_pixel:, :], img2[:, :left_pixel, :]])
+            img3 = cv2.hconcat([img3[:, left_pixel:, :], img3[:, :left_pixel, :]])
+        merged_image = cv2.vconcat([img2,img3])
+        out.write(merged_image)
+        out_sat.write(img1)
+    out.release()
+    out_sat.release()
+
+if __name__=='__main__':
+    import csv
+    img_dir = 'vis_video'
+    img_list = sorted(os.listdir(img_dir))
+    sat_list = []
+    grd_list = []
+    for img in img_list:
+        if '.png' in img:
+            if 'satdepth'  in img:
+                continue
+            if 'grdView_pano.png' in img:
+                continue
+            if 'grdView' in img:
+                if '_depth.png' not in img:
+                    grd_list.append(img)
+            elif 'satView' in img:
+                sat_list.append(img)
+    sat_list = sat_list[:-1]
+    grd_list = grd_list[:-1]
+    media_path = os.path.join(img_dir,'output_cat.mp4')
+    angle_list = []
+    with open(os.path.join(img_dir,'pixels.csv') , 'r') as csvfile:
+        reader = csv.DictReader(csvfile)
+        for row in reader:
+            angle = float(row['angle'])
+            angle_list.append(angle)
+    print(angle_list)
+
+    img_pair2vid(sat_list,grd_list,angle_list,media_path= media_path)
+    print('save 2 ',media_path)
\ No newline at end of file
diff --git a/inference/img2vid_interpolation.py b/inference/img2vid_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..462dd2829e472652bd066d71b2b39bb87de5ba77
--- /dev/null
+++ b/inference/img2vid_interpolation.py
@@ -0,0 +1,26 @@
+import os  
+import cv2  
+from PIL import Image  
+
+
+def img_pair2vid(sat_list,media_path= 'interpolation.mp4'):
+    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
+    out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 128))
+    for i  in range(len(sat_list)):
+
+        img1 = cv2.imread(os.path.join( img_dir , sat_list[i]))
+
+        out.write(img1)
+    out.release()
+
+if __name__=='__main__':
+    import csv
+    img_dir = 'vis_interpolation'
+    img_list = sorted(os.listdir(img_dir))
+    sat_list = []
+    for img in img_list:
+        sat_list.append(img)
+    media_path = os.path.join(img_dir,'interpolation.mp4')
+
+    img_pair2vid(sat_list,media_path= media_path)
+    print('save 2 ',media_path)
\ No newline at end of file
diff --git a/inference/quick_demo_interpolation.sh b/inference/quick_demo_interpolation.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ad760675784c53cd3a3b2c4df38c6efb3c2e619a
--- /dev/null
+++ b/inference/quick_demo_interpolation.sh
@@ -0,0 +1,10 @@
+CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --task=test_interpolation \
+--yaml=sat2density_cvact \
+--test_ckpt_path=2u87bj8w \
+--sty_img1=YL81FiK9PucIvAkr1FHkpA_grdView.png \
+--sty_img2=pdZmLHYEhe2PHj_8-WHMhw_grdView.png \
+--demo_img=VAMM6sIEbYAY5E6ZD_RMKg_satView_polish.png \
+--data.root=demo_img
+
+
+python inference/img2vid_interpolation.py
\ No newline at end of file
diff --git a/inference/quick_demo_video.sh b/inference/quick_demo_video.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cc8d9732df05c5056582bfe4e344b07a43fb3d73
--- /dev/null
+++ b/inference/quick_demo_video.sh
@@ -0,0 +1,32 @@
+sat_name=__-DFIFxvZBCn1873qkqXA_satView_polish.png
+
+sty_name=VAMM6sIEbYAY5E6ZD_RMKg_grdView.png
+
+# Select points:
+# First find a starting point, then press and hold the left mouse button and draw any shape.
+# Release the left mouse button and press 'q' to finish point selection.
+
+# It is best to select regions near the center of the satellite image.
+# python inference/select_points.py ${sat_name}
+
+# inference
+# If you want to use the illumination from another image, add --sty_img=WsKPDHEgLwrhrJXcUU34xA_grdView.png
+CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --yaml=sat2density_cvact \
+--test_ckpt_path=2u87bj8w \
+--task=test_vid \
+--demo_img=${sat_name} --sty_img=${sty_name} \
+--data.root=demo_img
+
+
+# make video
+python inference/img2vid.py
+
+# Visualize vis_video/volume_data.vtk with ParaView.
+
+
+# python test.py --yaml=sat2density_cvact \
+#     --test_ckpt_path=2u87bj8w \
+#     --task=test_vid \
+#     --demo_img=__-DFIFxvZBCn1873qkqXA_satView_polish.png \
+#     --sty_img=VAMM6sIEbYAY5E6ZD_RMKg_grdView.png \
+#     --data.root=demo_img
diff --git a/inference/select_points.py b/inference/select_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..36cae5d190000a4ab8b4693faa26ee0cd142005d
--- /dev/null
+++ b/inference/select_points.py
@@ -0,0 +1,110 @@
+import matplotlib.pyplot as plt
+from matplotlib.widgets import Cursor
+from matplotlib.image import imread
+import numpy as np
+import csv,os
+from scipy.interpolate import interp1d
+import sys
+data = sys.argv[1]
+assert data.endswith('satView_polish.png')
+dirs = os.path.join('dataset/CVACT/satview_correct',data)
+if not os.path.exists(dirs):
+    dirs = dirs.replace('dataset/CVACT','demo_img')
+sav_pth = 'vis_video'
+if not os.path.exists(sav_pth):
+    os.mkdir(sav_pth)
+
+img = imread(dirs)
+
+fig = plt.figure()
+fig.set_size_inches(1,1,forward=False)
+ax = plt.Axes(fig, [0., 0., 1., 1.])
+ax.set_axis_off()
+ax.imshow(img)
+
+coords = []
+
+def ondrag(event):
+    if event.button != 1:
+        return
+    x, y = int(event.xdata), int(event.ydata)
+    coords.append((x, y))
+    ax.plot([x], [y], 'o', color='red')
+    fig.canvas.draw_idle()
+fig.add_axes(ax)
+cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
+fig.canvas.mpl_connect('motion_notify_event', ondrag)
+plt.show()
+plt.close()
+
+
+# Deduplicate the dragged coordinates while preserving drawing order.
+pixels = list(dict.fromkeys(coords))
+print(pixels)
+
+###########################################
+
+from scipy.interpolate import splprep, splev
+
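+# Fit a smoothing B-spline (s=25) through the drawn points and resample 80 evenly
+# spaced parameter values along the curve.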
+points = pixels
+points = np.array(points)
+tck, u = splprep(points.T, s=25, per=0)
+u_new = np.linspace(u.min(), u.max(), 80)
+x_new, y_new = splev(u_new, tck)
+
+plt.plot(points[:,0], points[:,1], 'ro', label='Original curve')
+plt.plot(x_new, y_new, 'b-', label='Smooth curve')
+plt.legend()
+plt.show()
+plt.close()
+
+fig, ax = plt.subplots()
+
+
+pixels  = [tuple(sublist[:2]) for sublist in zip(x_new,y_new)]
+###########################################
+img = imread(dirs)
+fig, ax = plt.subplots()
+ax.set_xticks([])
+ax.set_yticks([])
+ax.imshow(img)
+plt.plot(x_new, y_new, 'r-', label='Smooth curve')
+fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
+plt.savefig(os.path.join(sav_pth,os.path.basename(dirs)).replace('.png','_sat_track.png'),bbox_inches='tight', pad_inches=0)
+plt.close()
+
+###########################################
+angle_list = []
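+# For each pair of consecutive curve points, compute the segment heading
+# (0 deg = image-up, wrapped to (-180, 180]) and save a frame with an arrow
+# drawn on the satellite image.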
+for i,pixel in enumerate(pixels[:-1]):
+    img = imread(dirs)
+
+    x1, y1 = pixel
+    x2, y2 = pixels[i+1]
+    dx, dy = x2 - x1, y2 - y1
+    angle_save = np.degrees(np.arctan2(dy, dx))+90
+    if angle_save>180:
+        angle_save = angle_save-360
+    angle_list.append(angle_save)
+    length = np.sqrt(dx ** 2 + dy ** 2)
+    angle = np.arctan2(dy, dx) * 180 / np.pi
+    fig, ax = plt.subplots()
+    ax.set_xticks([])
+    ax.set_yticks([])
+    ax.imshow(img)
+    ax.arrow(x1, y1, dx*10, dy*10, color='red', width=length, head_width=4*length, head_length=5*length)
+    
+    name = '_sat'+'%05d' % int(i) + ".png"
+    plt.savefig(os.path.join(sav_pth,os.path.basename(dirs)).replace('.png',name),bbox_inches='tight')
+    plt.close()
+
+
+with open( os.path.join(sav_pth,'pixels.csv'), 'w', newline='') as csvfile:
+    writer = csv.writer(csvfile)
+    writer.writerow(['x', 'y','angle'])
+    for i, (x, y) in enumerate(pixels[:-1]):
+        writer.writerow([x, y,angle_list[i]])
+print('save to pixels.csv',len(pixels[:-1]))
\ No newline at end of file
diff --git a/inference/single_style_test_cvact.sh b/inference/single_style_test_cvact.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8def276ac04fc4e689e00cc288b737b9d4f7d74e
--- /dev/null
+++ b/inference/single_style_test_cvact.sh
@@ -0,0 +1,9 @@
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=t1DOCdyniuWDC5JPqm4MWA_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=XefdeYLN_XZEaG2VLPFVtA_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=H2y6n9mCg53Ip1-0_UigRQ_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=YOBJgPIILw9PbSFvnYZFZg_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=lqgXjFM3zR8EWbiWWfgjNA_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=fOX6greOFJMH8IlA8Gm5hg_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=uZoS7QAxnEGlw22PtslB_Q_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=FVJZ86vbU43hYf4-uM4lFg_grdView.png
+python offline_train_test.py  --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=uHD7qJude23nqRPLVrdKIA_grdView.png
\ No newline at end of file
diff --git a/inference/single_style_test_cvusa.sh b/inference/single_style_test_cvusa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc529e24d3c259aed28bb0e3b2073a7be85b952a
--- /dev/null
+++ b/inference/single_style_test_cvusa.sh
@@ -0,0 +1,9 @@
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0001227.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0044093.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0015421.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0040767.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0014628.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0027413.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0021324.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0013073.jpg
+CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4  --sty_img=0028546.jpg
diff --git a/inference/synthesis_video.sh b/inference/synthesis_video.sh
new file mode 100644
index 0000000000000000000000000000000000000000..51424a8517ff8ded75821a724f0aa9fe1af9667a
--- /dev/null
+++ b/inference/synthesis_video.sh
@@ -0,0 +1,16 @@
+### a demo for synthesizing a ground-view video
+name=__-DFIFxvZBCn1873qkqXA_satView_polish.png
+
+# Select points:
+# First find a starting point, then press and hold the left mouse button and draw any shape.
+# Release the left mouse button and press 'q' to finish point selection.
+
+# It is best to select regions near the center of the satellite image.
+python inference/select_points.py ${name}
+
+# inference
+# If you want to use the illumination from another image, add --sty_img=WsKPDHEgLwrhrJXcUU34xA_grdView.png
+CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --task=test_vid --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --demo_img=${name}
+
+# make video
+python inference/img2vid.py
\ No newline at end of file
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e731b482fc158f51e8a1b1b6b2621e4c0449ea8e
Binary files /dev/null and b/model/__pycache__/__init__.cpython-38.pyc differ
diff --git a/model/__pycache__/base_model.cpython-38.pyc b/model/__pycache__/base_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9524dcddb0cc6363017d0e966e0117927837f5f
Binary files /dev/null and b/model/__pycache__/base_model.cpython-38.pyc differ
diff --git a/model/__pycache__/craft_feature.cpython-38.pyc b/model/__pycache__/craft_feature.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62b9f2cf7b5051cb8c45316bfd596a1067d65a44
Binary files /dev/null and b/model/__pycache__/craft_feature.cpython-38.pyc differ
diff --git a/model/__pycache__/geometry_transform.cpython-38.pyc b/model/__pycache__/geometry_transform.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..921aee90ab4d8c80090f4ba55b13989d9849dbc4
Binary files /dev/null and b/model/__pycache__/geometry_transform.cpython-38.pyc differ
diff --git a/model/base_model.py b/model/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdab0f131e1881ed51578cd2027062dad910ab99
--- /dev/null
+++ b/model/base_model.py
@@ -0,0 +1,572 @@
+import os
+import torch
+from abc import ABC, abstractmethod
+import wandb
+import options
+import utils
+from pytorch_msssim import ssim, SSIM
+import numpy as np
+import torchvision
+from tqdm import tqdm
+import lpips
+from imaginaire.losses import FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss,GANLoss
+import cv2
+from imaginaire.utils.trainer import get_scheduler
+from .geometry_transform import render_sat
+from model import geometry_transform
+import csv
+
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>:                     unpack data from dataset and apply preprocessing.
+        -- <forward>:                       produce intermediate results.
+        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt,wandb=None):
+        """Initialize the BaseModel class.
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>
+        Then, you need to define two lists:
+            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
+            -- self.model_names (str list):         define networks used in our training.
+        """
+        self.wandb = wandb
+        if opt.isTrain:
+            opt.save_dir =wandb.dir
+            options.save_options_file(opt,opt.save_dir)
+        self.opt = opt
+        self.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
+        # torch.backends.cudnn.benchmark = True
+        self.model_names = []
+        self.train_loader = None
+        self.val_loader = None 
+        self.sty_loader = None
+        self.loss_fn_alex = lpips.LPIPS(net='alex',eval_mode=True).cuda()
+        if opt.task=='test':
+            self.loss_fn_sque = lpips.LPIPS(net='squeeze',eval_mode=True).cuda()
+        self.mseloss = torch.nn.MSELoss(True,True)
+        self.criteria = {}
+        self.weights = {}
+        if hasattr(opt.optim.loss_weight, 'GaussianKL'):
+            if opt.optim.loss_weight.GaussianKL:
+                self.criteria['GaussianKL'] = GaussianKLLoss()
+                self.weights['GaussianKL'] = opt.optim.loss_weight.GaussianKL
+        if hasattr(opt.optim.loss_weight, 'L1'):
+            if opt.optim.loss_weight.L1:
+                self.criteria['L1']  = torch.nn.L1Loss(True,True)
+                self.weights['L1'] = opt.optim.loss_weight.L1
+        if hasattr(opt.optim.loss_weight, 'L2'):
+            if opt.optim.loss_weight.L2: 
+                self.criteria['L2'] = torch.nn.MSELoss(True,True)
+                self.weights['L2'] = opt.optim.loss_weight.L2
+        if hasattr(opt.optim.loss_weight, 'SSIM'):
+            if opt.optim.loss_weight.SSIM: 
+                self.criteria['SSIM'] = SSIM(data_range =1., size_average=True, channel=3)
+                self.weights['SSIM']  = opt.optim.loss_weight.SSIM
+        if hasattr(opt.optim.loss_weight, 'Perceptual'):
+            if opt.optim.loss_weight.Perceptual: 
+                self.criteria['Perceptual'] = \
+                    PerceptualLoss(
+                        network=opt.optim.perceptual_loss.mode,
+                        layers=opt.optim.perceptual_loss.layers,
+                        weights=opt.optim.perceptual_loss.weights).to(self.device)
+                self.weights['Perceptual'] = opt.optim.loss_weight.Perceptual
+        if hasattr(opt.optim.loss_weight, 'sky_inner'):
+            if opt.optim.loss_weight.sky_inner:
+                self.criteria['sky_inner'] = torch.nn.L1Loss(True,True)
+                self.weights['sky_inner'] = opt.optim.loss_weight.sky_inner
+        if hasattr(opt.optim.loss_weight, 'feature_matching'):
+            if opt.optim.loss_weight.feature_matching:
+                self.criteria['feature_matching'] = FeatureMatchingLoss()
+                self.weights['feature_matching'] = opt.optim.loss_weight.feature_matching
+        self.weights['GAN'] = opt.optim.loss_weight.GAN
+        self.criteria['GAN'] = GANLoss(gan_mode=opt.optim.gan_mode)
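+        # Note: the generator objective assembled later (in the subclass's backward_G)
+        # is the weighted sum  L_G = sum_k weights[k] * criteria[k](...), taken over
+        # whichever of {GaussianKL, L1, L2, SSIM, Perceptual, sky_inner,
+        # feature_matching, GAN} are enabled in opt.optim.loss_weight above.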
+
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        pass
+
+    def save_checkpoint(self,ep=0,latest=False):
+        """
+        save trained models.
+        Args:
+            ep (int, optional): model epochs. Defaults to 0.
+            latest (bool, optional): whether it is the latest model. Defaults to False.
+        """        
+        ckpt_save_path = os.path.join(self.wandb.dir,'checkpoint')
+        if not os.path.exists(ckpt_save_path):
+            os.mkdir(ckpt_save_path)
+        utils.save_checkpoint(self,ep=ep,latest=latest,output_path=ckpt_save_path)
+        if not latest:
+            print("checkpoint saved: {0}, epoch {1} ".format(self.opt.name,ep))
+
+
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+               (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+
+    def setup_optimizer(self,opt):
+        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.optim.lr_gen, betas=(opt.optim.beta1, 0.999),eps=1.e-7)
+        if opt.isTrain:
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.optim.lr_dis, betas=(opt.optim.beta1, 0.999))
+        if opt.optim.lr_policy:
+            self.sch_G = get_scheduler(opt.optim, self.optimizer_G)
+            self.sch_D = get_scheduler(opt.optim, self.optimizer_D)
+
+    def optimize_parameters(self,opt):
+        self.netG.train()
+        # update Discriminators
+        self.backward_D(opt)                # calculate gradients for D
+
+        # update Generator
+        self.backward_G(opt)                   # calculate gradients for G
+
+        psnr1 = -10*self.mseloss(self.fake_B.detach(),self.real_B.detach()).log10().item()
+        ssim_ = ssim(self.real_B.detach().float(), self.fake_B.detach().float(),data_range=1.)
+
+        out_dict = {
+                        "train_ssim": ssim_,
+                        "train_psnr1": psnr1,
+                    }
+        # adjust learning rates according to schedule 
+        if opt.optim.lr_policy:
+            out_dict["lr_D"]=self.sch_D.get_lr()[0]
+            out_dict["lr_G"]=self.sch_G.get_lr()[0]
+        out_dict.update(self.loss)
+        out_dict.update(self.dis_losses)
+        self.wandb.log(out_dict)
+
+    def validation(self,opt):
+        """Used for validation and test in Center Ground-View Synthesis setting
+
+        Args:
+            opt (_type_): option dict
+        """        
+        print(10*"*","validate",10*"*")
+        self.netG.eval()
+        # six image reconstruction metrics
+        psnr_val = []
+        ssim_val = []
+        lpips_ale_val = []
+        lpips_squ_val = []
+        rmse_val = []
+        sd_val = []
+        with torch.no_grad():
+            # set the sky of all images with predefined sky histogram.
+            if opt.sty_img:
+                for _,data in enumerate(self.sty_loader):
+                    self.set_input(data)
+                    self.style_temp=self.sky_histc
+                    break
+
+            for _,data in enumerate(tqdm(self.val_loader,ncols=100)):
+                self.set_input(data)
+                # if true: use the sky of predefined image
+                # if false: use the sky of corresponding GT
+                if opt.sty_img:
+                    self.sky_histc = self.style_temp
+                
+                self.forward(opt)
+                rmse = torch.sqrt(self.mseloss(self.fake_B*255.,self.real_B*255.)).item()
+                sd = sd_func(self.real_B,self.fake_B)
+                rmse_val.append(rmse)
+                sd_val.append(sd)
+
+                psnr1 = -10*self.mseloss(self.fake_B,self.real_B).log10().item()
+                ssim_ = ssim(self.real_B, self.fake_B,data_range=1.).item()
+                lpips_ale = torch.mean(self.loss_fn_alex((self.real_B*2.)-1, (2.*self.fake_B)-1)).cpu()
+                if opt.task=='test':
+                    lpips_sque = torch.mean(self.loss_fn_sque((self.real_B*2.)-1, (2.*self.fake_B)-1)).cpu()
+                    lpips_squ_val.append(lpips_sque)
+                psnr_val.append(psnr1)
+                ssim_val.append(ssim_)
+                lpips_ale_val.append(lpips_ale)
+                    
+                if opt.task in ['vis_test']:
+                    if not os.path.exists(opt.vis_dir):
+                        os.mkdir(opt.vis_dir)
+
+                    sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel'])
+
+                    self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255.
+                    sat_depth = (sat_depth/sat_depth.max())*255.
+                    for i in range(len(self.fake_B)):
+                        depth_save  = cv2.applyColorMap(self.out_put['depth'][i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
+                        depth_sat_save = cv2.applyColorMap(sat_depth[i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
+                        # cat generated ground images, GT ground images, predicted ground depth
+                        torchvision.utils.save_image([self.fake_B[i].cpu(),self.real_B[i].cpu(),torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i])))
+                        # cat GT satellite images, predicted satellite depth
+                        torchvision.utils.save_image( [self.real_A[i].cpu() ,torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]).rsplit('.', 1)[0]+'_sat.jpg'))
+                        # ground opacity
+                        torchvision.utils.save_image([self.out_put['opacity'][i]] ,os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]).rsplit('.', 1)[0]+'_opacity.jpg'))
+        psnr_avg = np.average(psnr_val)
+        ssim_avg = np.average(ssim_val)
+
+        lpips_ale_avg = np.average(lpips_ale_val)
+        # squeeze-net LPIPS is only collected when opt.task == 'test' (see above)
+        lpips_squ_avg = np.average(lpips_squ_val) if lpips_squ_val else float('nan')
+
+        rmse_avg = np.average(rmse_val)
+        sd_avg = np.average(sd_val)
+        if opt.task in ["train" , "Train"]:
+            out_dict =   {
+                            'val_psnr': psnr_avg,
+                            'val_ssim': ssim_avg,
+                            'val_lpips_ale':lpips_ale_avg,
+                            'val_rmse':rmse_avg,
+                            'val_sd':sd_avg
+                            }  
+            if lpips_squ_val:
+                out_dict['val_lpips_squ'] = lpips_squ_avg
+            self.wandb.log(out_dict,commit=False)
+        else:
+            print(
+                {
+                'val_rmse':rmse_avg,
+                'val_ssim': ssim_avg,
+                'val_psnr': psnr_avg,
+                'val_sd':sd_avg,
+                'val_lpips_ale':lpips_ale_avg,
+                'val_lpips_squ':lpips_squ_avg,
+                }
+                )
+            with open('test_output.csv', mode='a', newline='') as csv_file:
+                writer = csv.writer(csv_file)
+                writer.writerow([rmse_avg, ssim_avg, psnr_avg, sd_avg, lpips_ale_avg, lpips_squ_avg])
+                
+    def test_vid(self,opt):
+        """Used for synthesis ground video
+
+        Args:
+            opt (_type_): option dict
+        """        
+        ckpt_list = os.listdir('wandb/')
+        for i in ckpt_list:
+            if opt.test_ckpt_path in i:
+                ckpt_path = i
+        
+        ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
+        print('load success!')
+        self.netG.load_state_dict(ckpt,strict=True)
+        self.netG.eval()
+        print(10*"*","test_video",10*"*")
+
+
+        pixels = []
+        if os.path.exists('vis_video/pixels.csv'):
+
+            with open('vis_video/pixels.csv', 'r') as csvfile:
+                reader = csv.DictReader(csvfile)
+                for row in reader:
+                    x = float(row['x'])  # x, y: pixel coordinates of the camera on the satellite image
+                    y = float(row['y'])
+                    pixels.append((x, y))
+        else:
+            print('vis_video/pixels.csv not found; only rendering the center point')
+            pixels = [(128,128)]
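+        # (x, y) are pixel coordinates of the camera on the (assumed 256x256) satellite
+        # image; below they are mapped to opt.origin_H_W in the normalized [-1, 1] space.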
+
+        if opt.sty_img:
+            # inference with illumination from other images
+            for idx,data in enumerate(self.sty_loader):
+                self.set_input(data)
+                self.style_temp=self.sky_histc
+                break
+        with torch.no_grad():
+            for idx,data in enumerate(self.val_loader):
+                self.set_input(data)
+                if opt.sty_img:
+                    self.sky_histc = self.style_temp
+                for i,(x,y) in enumerate(pixels):
+                    opt.origin_H_W = [(y-128)/128 , (x-128)/128]
+                    print(opt.origin_H_W)
+                    self.forward(opt)
+
+
+
+                    if not os.path.exists('vis_video'):
+                        os.mkdir('vis_video')
+
+                    # save the voxel for visualization & the satellite depth; works well on CVACT
+                    if i==0:
+                        # pre-process for better visualization
+                        volume_data = self.out_put.voxel.squeeze().cpu().numpy().transpose((1,2,0))
+                        volume_data = np.clip(volume_data, None, 10)
+
+                        import pyvista as pv
+
+                        grid = pv.UniformGrid()
+                        grid.dimensions = volume_data.shape
+                        grid.spacing = (1, 1, 1)
+                        grid.origin = (0, 0, 0)
+                        grid.point_data['values'] = volume_data.flatten(order='F')
+                        grid.save(os.path.join('vis_video',"volume_data.vtk") ) # the vtk file can be visualized with ParaView
+
+                        sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel'])
+                        sat_depth = (2 - sat_depth)/(opt.data.max_height/15)*255.
+                        depth_sat_save = cv2.applyColorMap(sat_depth[0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
+                        torchvision.utils.save_image(torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0]) ,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png','_satdepth.png'))
+                        torchvision.utils.save_image( [self.real_A[0].cpu() ]                      ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_sat.png')))
+                        torchvision.utils.save_image( [self.real_B[0].cpu() ]                      ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_pano.png')))
+                        
+                    self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255.
+                    depth_save  = cv2.applyColorMap(self.out_put['depth'][0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
+                    depth_save = torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0])
+
+                    
+                    save_img = self.out_put.pred[0].cpu()
+                    name = '%05d' % int(i) + ".png"
+                    torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
+
+                    save_img = depth_save
+                    name = '%05d' % int(i) + "_depth.png"
+                    torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
+
+                    # save_img = self.out_put.generator_inputs[0][:3,:,:]
+                    # name = '%05d' % int(i) + "_color_project.png"
+                    # torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
+
+    def test_interpolation(self,opt):
+        """Used for test interpolation
+
+        Args:
+            opt (_type_): option dict
+        """        
+        ckpt_list = os.listdir('wandb/')
+        for i in ckpt_list:
+            if opt.test_ckpt_path in i:
+                ckpt_path = i
+        
+        ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
+        print('load success!')
+        self.netG.load_state_dict(ckpt,strict=True)
+        self.netG.eval()
+
+        pixels = [(128,128)]
+        if opt.sty_img1:
+            for idx,data in enumerate(self.sty_loader1):
+                self.set_input(data)
+                self.style_temp1=self.sky_histc
+                break
+        if opt.sty_img2:
+            for idx,data in enumerate(self.sty_loader2):
+                self.set_input(data)
+                self.style_temp2=self.sky_histc
+                break
+        
+        with torch.no_grad():
+            for idx,data in enumerate(self.val_loader):
+                self.set_input(data)
+                self.sky_histc1 = self.style_temp1
+                self.sky_histc2 = self.style_temp2
+                x,y =  pixels[0]
+                opt.origin_H_W = [(y-128)/128 , (x-128)/128]
+                print(opt.origin_H_W)
+                    
+
+                estimated_height = self.netG.depth_model(self.real_A)
+                geo_outputs = geometry_transform.render(opt,self.real_A,estimated_height,self.netG.pano_direction,PE=self.netG.PE)
+                generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+                if self.netG.gen_cfg.cat_opa:
+                    generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+                if self.netG.gen_cfg.cat_depth:
+                    generator_inputs = torch.cat((generator_inputs,depth),dim=1)
+                _, _, z1 = self.netG.style_encode(self.sky_histc1)
+                _, _, z2 = self.netG.style_encode(self.sky_histc2)
+                num_inter = 60
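+                # Illumination interpolation sketch: for t = i/(num_inter-1), the style
+                # latent z(t) = (1 - t) * z1 + t * z2 is decoded into one frame, giving a
+                # smooth transition between the skies of sty_img1 and sty_img2.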
+                for i in range(num_inter):
+                    z = z1 * (1-i/(num_inter-1)) + z2* (i/(num_inter-1))
+                    z = self.netG.style_model(z)
+                    output_RGB = self.netG.denoise_model(generator_inputs,z)
+
+                    save_img = output_RGB.cpu()
+                    name = 'img{:03d}.png'.format(i)
+                    if not os.path.exists('vis_interpolation'):
+                        os.mkdir('vis_interpolation')
+                    torchvision.utils.save_image(save_img,os.path.join('vis_interpolation',name))
+
+
+
+                        
+                  
+
+    def test_speed(self,opt):
+        self.netG.eval()
+        random_input = torch.randn(1, 3, 256, 256).to(opt.device)
+        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+        iterations  = 300
+
+        times = torch.zeros(iterations)
+        with torch.no_grad():
+            for _ in range(50):
+                _ = self.netG(random_input,None,opt)
+            for iter in range(iterations):
+                starter.record()
+                _ = self.netG(random_input,None,opt)
+                ender.record()
+                torch.cuda.synchronize() 
+                curr_time = starter.elapsed_time(ender) # elapsed time in milliseconds
+                times[iter] = curr_time
+        # print(curr_time)
+
+        mean_time = times.mean().item()
+        print("Inference time: {:.6f}, FPS: {} ".format(mean_time, 1000/mean_time))
+
+
+    def test_sty(self,opt):
+        ckpt_list = os.listdir('wandb/')
+        for i in ckpt_list:
+            if opt.test_ckpt_path in i:
+                ckpt_path = i
+        
+        ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
+        print('load success!')
+        self.netG.load_state_dict(ckpt,strict=True)
+        self.netG.eval()
+        print(10*"*","test_sty",10*"*")
+        self.netG.eval()
+        self.style_temp_list = []
+        if not os.path.exists(opt.vis_dir):  # create the output directory before the first save below
+            os.mkdir(opt.vis_dir)
+        with torch.no_grad():
+            num_val_loader = len(self.val_loader)
+            for i in range(num_val_loader):
+                for idx,data in enumerate(tqdm(self.val_loader,ncols=100)):
+                    self.set_input(data)
+                    
+                    if i==0:
+                        self.style_temp_list.append(self.sky_histc)
+                        name = '%05d' % int(idx)
+                        torchvision.utils.save_image( [self.real_A[0].cpu() ]  ,os.path.join(opt.vis_dir,os.path.basename(self.image_paths[0]).replace('.png',name+'_sat.png')))
+                    self.sky_histc = self.style_temp_list[i]
+                    self.forward(opt)
+                    if not os.path.exists(opt.vis_dir):
+                        os.mkdir(opt.vis_dir)
+                    name = '%05d' % int(idx)+'_'+'%05d' % int(i)
+                    name= name+ '.png'
+                    torchvision.utils.save_image(self.fake_B[0].cpu(),os.path.join(opt.vis_dir, name))
+
+    def train(self,opt):
+        self.validation(opt)
+        for current_epoch in range(opt.max_epochs):
+            print(10*'-','current epoch is ',current_epoch,10*'-')
+            for idx,data in enumerate(tqdm(self.train_loader,ncols=100)):
+                self.set_input(data)
+                self.optimize_parameters(opt)
+                if idx%500==0 :
+                    out_ing_dict = {
+                                    'train_input': wandb.Image(self.real_A[0].float()),
+                                    'train_pred_and_gt': wandb.Image(torch.cat([self.fake_B,self.real_B],2)[0].float()),
+                                    }
+                    if hasattr(self.out_put, 'inter_RGB'):
+                        out_ing_dict["train_inner_pred"] = wandb.Image(self.out_put.inter_RGB[0].float())
+                    if opt.arch.gen.transform_mode in ['volum_rendering']:
+                        out_ing_dict['train_inner_opacity'] = wandb.Image(self.out_put.opacity[0].float())
+                    self.wandb.log(out_ing_dict,commit=False)
+                if  opt.optim.lr_policy.iteration_mode:
+                    self.sch_G.step()
+                    self.sch_D.step()
+            if not opt.optim.lr_policy.iteration_mode:
+                self.sch_G.step()
+                self.sch_D.step()
+            self.validation(opt)
+            if current_epoch%5==0:
+                self.save_checkpoint(ep=current_epoch)
+        self.save_checkpoint(ep=current_epoch)
+
+    def test(self,opt):
+        ckpt_list = os.listdir('wandb/')
+        for i in ckpt_list:
+            if '.zip' not in i:
+                if opt.test_ckpt_path in i:
+                    ckpt_path = i
+        
+        ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
+        print('load success!')
+        self.netG.load_state_dict(ckpt,strict=True)
+        # print(10*"*","validate",10*"*")
+        self.validation(opt)
+        print('With --task=vis_test, visual results are saved; add "--vis_dir=xxx" to save them to a different directory. Current vis_dir:',opt.vis_dir)
+
+
+    def _get_outputs(self, net_D_output, real=True):
+        r"""Return output values. Note that when the gan mode is relativistic.
+        It will do the difference before returning.
+
+        Args:
+           net_D_output (dict):
+               real_outputs (tensor): Real output values.
+               fake_outputs (tensor): Fake output values.
+           real (bool): Return real or fake.
+        """
+
+        def _get_difference(a, b):
+            r"""Get difference between two lists of tensors or two tensors.
+
+            Args:
+                a: list of tensors or tensor
+                b: list of tensors or tensor
+            """
+            out = list()
+            for x, y in zip(a, b):
+                if isinstance(x, list):
+                    res = _get_difference(x, y)
+                else:
+                    res = x - y
+                out.append(res)
+            return out
+
+        if real:
+            return net_D_output['real_outputs']
+        else:
+            return net_D_output['fake_outputs']
+
+
+def sd_func(real, fake):
+    '''
+    ref: page 6 in https://arxiv.org/abs/1511.05440
+    '''
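+    # Sharpness Difference (SD) sketch, restating the code below: image gradients of the
+    # GT and the prediction are taken along both spatial axes and compared as
+    #   SD = 10 * log10( max_I**2 / mean |(d_h + d_w)GT - (d_h + d_w)pred| ),
+    # with intensities assumed to lie in [0, 1] (hence the 1.**2 numerator).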
+    dgt1 = torch.abs(torch.diff(real,dim=-2))[:, :, 1:, 1:-1]
+    dgt2 = torch.abs(torch.diff(real, dim=-1))[:, :, 1:-1, 1:]
+    dpred1 = torch.abs(torch.diff(fake, dim=-2))[:, :, 1:, 1:-1]
+    dpred2 = torch.abs(torch.diff(fake, dim=-1))[:, :, 1:-1, 1:]
+    return 10*torch.log10(1.**2/torch.mean(torch.abs(dgt1+dgt2-dpred1-dpred2))).cpu().item()
\ No newline at end of file
diff --git a/model/craft_feature.py b/model/craft_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e073bc3c05dd0e308336ef6b41fa8c621727b11
--- /dev/null
+++ b/model/craft_feature.py
@@ -0,0 +1,146 @@
+import torch
+from .base_model import BaseModel
+import importlib
+from  torch.utils.data import DataLoader
+from easydict import EasyDict as edict
+
+class Model(BaseModel):
+    def __init__(self, opt, wandb=None):
+
+        """Initialize the Generator.
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseModel.__init__(self, opt,wandb)
+        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
+
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+            self.real_A: aerial images
+            self.real_B: ground images
+            self.image_paths: images paths of ground images
+            self.sky_mask: the sky mask of ground images
+            self.sky_histc: the histogram of selected sky
+        """     
+        self.real_A = input['sat' ].to(self.device)
+        self.real_B = input['pano'].to(self.device) if 'pano' in input else None # for testing
+        self.image_paths = input['paths']
+        if self.opt.data.sky_mask:
+            self.sky_mask = input['sky_mask'].to(self.device) if 'sky_mask' in input else None # for testing
+        if self.opt.data.histo_mode and self.opt.data.sky_mask:
+            self.sky_histc = input['sky_histc'].to(self.device) if 'sky_histc' in input else None # for testing
+        else: self.sky_histc = None
+
+    def forward(self,opt):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        # origin_H_W is the initial location of the camera
+        if opt.task != 'test_vid':
+            opt.origin_H_W=None
+        if hasattr(opt.arch.gen,'style_inject'):
+            # replace the predicted sky with selected sky histogram
+            if opt.arch.gen.style_inject == 'histo':
+                self.out_put =  self.netG(self.real_A,self.sky_histc.detach(),opt) 
+            else:
+                raise Exception('Unknown style inject mode')
+        else:
+            self.out_put =  self.netG(self.real_A,None,opt) 
+        self.out_put = edict(self.out_put)
+        self.fake_B = self.out_put.pred
+        # perceptive image
+
+    def backward_D(self,opt):
+        """Calculate GAN loss for the discriminator"""
+        self.optimizer_D.zero_grad()
+        self.netG.eval()
+        with torch.no_grad():
+            self.forward(opt)                   
+            self.out_put.pred = self.out_put.pred.detach()
+        net_D_output = self.netD(self.real_B, self.out_put)
+
+        output_fake = self._get_outputs(net_D_output, real=False)
+        output_real = self._get_outputs(net_D_output, real=True)
+        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
+        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
+        self.dis_losses = dict()
+        self.dis_losses['GAN/fake'] = fake_loss
+        self.dis_losses['GAN/true'] = true_loss
+        self.dis_losses['DIS'] = fake_loss + true_loss
+        self.dis_losses['DIS'].backward()
+        self.optimizer_D.step()          
+
+
+    def backward_G(self,opt):
+        self.optimizer_G.zero_grad()       
+        self.loss = {}
+        self.netG.train()
+        self.forward(opt) 
+        net_D_output = self.netD(self.real_B, self.out_put) 
+        pred_fake = self._get_outputs(net_D_output, real=False)
+        self.loss['GAN'] = self.criteria['GAN'](pred_fake, True, dis_update=False)
+        if 'GaussianKL' in self.criteria:
+            self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar'])
+        if 'L1' in self.criteria:
+            self.loss['L1'] = self.criteria['L1'](self.real_B,self.fake_B)
+        if 'L2' in self.criteria:
+            self.loss['L2'] = self.criteria['L2'](self.real_B,self.fake_B)
+        if 'SSIM' in self.criteria:
+            self.loss['SSIM'] = 1-self.criteria['SSIM'](self.real_B, self.fake_B)
+        if 'sky_inner' in self.criteria:
+            self.loss['sky_inner'] = self.criteria['sky_inner'](self.out_put.opacity, 1-self.sky_mask)
+        if 'Perceptual' in self.criteria:
+            self.loss['Perceptual'] = self.criteria['Perceptual'](self.fake_B,self.real_B)
+        if 'feature_matching' in self.criteria:
+            self.loss['feature_matching']  = self.criteria['feature_matching'](net_D_output['fake_features'], net_D_output['real_features'])
+        self.loss_G = 0
+        for key in self.loss:
+            self.loss_G += self.loss[key] * self.weights[key]
+        self.loss['total'] = self.loss_G 
+        self.loss_G.backward()
+        self.optimizer_G.step()             # update G's weights
+
+
+    def load_dataset(self,opt):
+        data = importlib.import_module("data.{}".format(opt.data.dataset))
+        if opt.task in ["train", "Train"]:
+            train_data = data.Dataset(opt,"train",opt.data.train_sub)
+            
+            self.train_loader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,num_workers=opt.data.num_workers,drop_last=True)
+            self.len_train_loader = len(self.train_loader)
+
+        val_data   = data.Dataset(opt,"val")
+        opt.batch_size = 1 if opt.task in ["test" , "val","vis_test",'test_vid','test_sty'] else opt.batch_size
+        opt.batch_size = 1 if opt.task=='test_speed' else opt.batch_size
+        self.val_loader = DataLoader(val_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.data.num_workers)
+        self.len_val_loader   = len(self.val_loader)
+        # you can select one image whose sky style is applied to all predicted skies
+        # if None, the style of the corresponding GT is used
+        if opt.sty_img:
+            sty_data = data.Dataset(opt,sty_img = opt.sty_img)
+            self.sty_loader = DataLoader(sty_data,batch_size=1,num_workers=1,shuffle=False)
+        # The following loaders are only used to test illumination interpolation.
+        if opt.sty_img1:
+            sty1_data = data.Dataset(opt,sty_img = opt.sty_img1)
+            self.sty_loader1 = DataLoader(sty1_data,batch_size=1,num_workers=1,shuffle=False)
+        if opt.sty_img2:
+            sty2_data = data.Dataset(opt,sty_img = opt.sty_img2)
+            self.sty_loader2 = DataLoader(sty2_data,batch_size=1,num_workers=1,shuffle=False)
+
+    def build_networks(self, opt):
+        if 'imaginaire' in opt.arch.gen.netG:
+            lib_G = importlib.import_module(opt.arch.gen.netG)
+            self.netG = lib_G.Generator(opt).to(self.device)
+        else:
+            raise Exception('Unknown generator function')
+
+        if opt.isTrain:  # define a discriminator; conditional GANs take both the input and the output images
+            if opt.arch.dis.netD == 'imaginaire.discriminators.multires_patch_pano':
+                lib_D = importlib.import_module(opt.arch.dis.netD)
+                self.netD = lib_D.Discriminator(opt.arch.dis).to(self.device)
+            else:
+                raise Exception('Unknown discriminator function')
diff --git a/model/geometry_transform.py b/model/geometry_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a349a1cf9ee2ea858c3f608365a701583d3977
--- /dev/null
+++ b/model/geometry_transform.py
@@ -0,0 +1,272 @@
+import numpy as np
+import torch,math
+from PIL import Image
+import torchvision
+from easydict import EasyDict as edict
+
+def position_produce(opt): 
+    depth_channel =  opt.arch.gen.depth_arch.output_nc 
+    if  opt.optim.ground_prior:
+        depth_channel = depth_channel+1
+    z_ = torch.arange(depth_channel)/depth_channel
+    x_ = torch.arange(opt.data.sat_size[1])/opt.data.sat_size[1]
+    y_ = torch.arange(opt.data.sat_size[0])/opt.data.sat_size[0]
+    Z,X,Y = torch.meshgrid(z_,x_,y_)
+    input = torch.cat((Z[...,None],X[...,None],Y[...,None]),dim=-1).to(opt.device)
+    pos = positional_encoding(opt,input)
+    pos = pos.permute(3,0,1,2)
+    return  pos
+
+def positional_encoding(opt,input): # [B,...,N]
+    shape = input.shape
+    freq = 2**torch.arange(opt.arch.gen.PE_channel,dtype=torch.float32,device=opt.device)*np.pi # [L]
+    spectrum = input[...,None]*freq # [B,...,N,L]
+    sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
+    input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
+    input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
+    return input_enc
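+# Shape sketch for positional_encoding, restating the code above: with
+# L = opt.arch.gen.PE_channel, an input of shape [B, ..., N] is multiplied by the
+# frequencies pi * 2**0 ... pi * 2**(L-1), passed through sin and cos, and reshaped
+# to [B, ..., 2*N*L].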
+
+
+
+def get_original_coord(opt):
+    '''
+    pano_direction [X,Y,Z] x right,y up,z out
+    '''
+    W,H  = opt.data.pano_size
+    _y = np.repeat(np.array(range(W)).reshape(1,W), H, axis=0)
+    _x = np.repeat(np.array(range(H)).reshape(1,H), W, axis=0).T
+
+    if opt.data.dataset in ['CVACT_Shi', 'CVACT', 'CVACThalf']:
+        _theta = (1 - 2 * (_x) / H) * np.pi/2 # latitude 
+    elif opt.data.dataset in ['CVUSA']:
+        _theta = (1 - 2 * (_x) / H) * np.pi/4
+    # _phi = math.pi* ( 1 -2* (_y)/W ) # longtitude 
+    _phi = math.pi*( - 0.5 - 2* (_y)/W )
+    axis0 = (np.cos(_theta)*np.cos(_phi)).reshape(H, W, 1)
+    axis1 = np.sin(_theta).reshape(H, W, 1) 
+    axis2 = (-np.cos(_theta)*np.sin(_phi)).reshape(H, W, 1) 
+    pano_direction = np.concatenate((axis0, axis1, axis2), axis=2)
+    return pano_direction  
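+# In other words, every panorama pixel (row _x, column _y) is assigned a latitude
+# _theta (vertical span of +-90 deg for CVACT, +-45 deg for CVUSA) and a longitude
+# _phi, and converted into a unit ray direction with x right, y up, z out.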
+
+
+def render(opt,feature,voxel,pano_direction,PE=None):
+    '''
+    render ground images from satellite images
+
+    feature: B,C,H_sat,W_sat feature map or an input RGB
+    voxel: B,N,H_sat,W_sat density of each grid cell
+    PE: optional positional encoding; default is None
+    pano_direction: panorama ray directions as defined in get_original_coord
+    '''
+    # pano_W,pano_H = opt.data.pano_size
+    sat_W,sat_H = opt.data.sat_size
+    BS = feature.size(0)
+    ##### get origin, sample point ,depth
+
+    if opt.data.dataset =='CVACT_Shi':
+        origin_height=2       ## height at which the ground photo is taken, in real-world scale
+        realworld_scale = 30  ## real-world extent corresponding to the [-1,1] regular coordinate
+    elif opt.data.dataset == 'CVUSA':
+        origin_height=2       
+        realworld_scale = 55  
+    else:
+        raise NotImplementedError('dataset not implemented yet')
+
+    assert sat_W==sat_H
+    pixel_resolution = realworld_scale/sat_W #### pixel resolution of the satellite image in the real world
+
+    if opt.data.sample_total_length:
+        sample_total_length = opt.data.sample_total_length
+    else: sample_total_length = (int(max(np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(2)**2), \
+        np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(opt.data.max_height-origin_height)**2))/pixel_resolution))/(sat_W/2)
+
+    origin_z = torch.ones([BS,1])*(-1+(origin_height/(realworld_scale/2))) ### -1 is the lowest position in the regular coordinate
+    ##### origin_z is determined by the camera height
+    if opt.origin_H_W is None: ### origin_H_W is the camera position in the regular coordinate
+        origin_H,origin_w = torch.zeros([BS,1]),torch.zeros([BS,1])   
+    else:
+        origin_H,origin_w = torch.ones([BS,1])*opt.origin_H_W[0],torch.ones([BS,1])*opt.origin_H_W[1]
+    origin = torch.cat([origin_w,origin_z,origin_H],dim=1).to(opt.device)[:,None,None,:]  ## (w,z,h), similar to the NeRF coordinate definition
+    sample_len = ((torch.arange(opt.data.sample_number)+1)*(sample_total_length/opt.data.sample_number)).to(opt.device)
+    ### sample_len: the sampling step is fixed, so the distances along a ray follow from the total length and the number of samples
+    origin = origin[...,None]
+    pano_direction = pano_direction[...,None] ### the direction has been normalized
+    depth = sample_len[None,None,None,None,:]
+    sample_point = origin + pano_direction * depth  # sample points in (w,z,h) order
+    # x points right, y points up, z points backwards scene nerf
+    # ray_depth = sample_point-origin
+
+    if opt.optim.ground_prior:
+        voxel = torch.cat([torch.ones(voxel.size(0),1,voxel.size(2),voxel.size(3),device=opt.device)*1000,voxel],1)
+
+            # voxel[:,0,:,:] = 100
+    N = voxel.size(1)
+    voxel_low = -1
+    voxel_max = -1 + opt.data.max_height/(realworld_scale/2)  ### top of the voxel volume in the normalized space
+    grid = sample_point.permute(0,4,1,2,3)[...,[0,2,1]] ### BS,NUM_point,W,H,3
+    grid[...,2]   = ((grid[...,2]-voxel_low)/(voxel_max-voxel_low))*2-1  ### rescale the z axis so the grid lies in grid_sample's [-1,1] space
+    grid = grid.float()  ## [1, 300, 256, 512, 3]
+    
+    color_input = feature.unsqueeze(2).repeat(1, 1, N, 1, 1)
+    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)
+
+    color_grid = torch.nn.functional.grid_sample(color_input, grid)
+    if PE is not None:
+        PE_grid = torch.nn.functional.grid_sample(PE[None,...], grid[:1,...])
+        color_grid = torch.cat([color_grid,PE_grid.repeat(BS, 1, 1, 1, 1)],dim=1)
+
+    depth_sample = depth.permute(0,1,2,4,3).view(1,-1,opt.data.sample_number,1)
+    feature_size = color_grid.size(1)
+    color_grid = color_grid.permute(0,3,4,2,1).view(BS,-1,opt.data.sample_number,feature_size)
+    alpha_grid = alpha_grid.permute(0,3,4,2,1).view(BS,-1,opt.data.sample_number)
+    intv = sample_total_length/opt.data.sample_number
+    output = composite(opt, rgb_samples=color_grid,density_samples=alpha_grid,depth_samples=depth_sample,intv = intv)
+    output['voxel']  = voxel
+    return output
+
+def composite(opt,rgb_samples,density_samples,depth_samples,intv):
+    """Generate 2D ground-view images by compositing samples along each ray.
+
+    Args:
+        opt (_type_): option dict
+        rgb_samples (_type_): RGB values (sampled from the satellite feature) along each ray cast from the ground camera
+        density_samples (_type_): densities (sampled from the predicted voxel) along each ray cast from the ground camera
+        depth_samples (_type_): depths of the samples along each ray
+        intv (_type_): depth interval between consecutive samples
+
+    Returns:
+        2D ground-view images (rgb, opacity, and depth)
+    """
+    
+    sigma_delta = density_samples*intv # [B,HW,N]
+    alpha = 1-(-sigma_delta).exp_() # [B,HW,N]
+    T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)) .exp_() # [B,HW,N]
+    prob = (T*alpha)[...,None] # [B,HW,N,1]
+    # integrate RGB and depth weighted by probability
+    depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]
+    rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3]
+    opacity = prob.sum(dim=2) # [B,HW,1]
+    depth = depth.permute(0,2,1).view(depth.size(0),-1,opt.data.pano_size[1],opt.data.pano_size[0])
+    rgb = rgb.permute(0,2,1).view(rgb.size(0),-1,opt.data.pano_size[1],opt.data.pano_size[0])
+    opacity = opacity.view(opacity.size(0),1,opt.data.pano_size[1],opt.data.pano_size[0])
+    return {'rgb':rgb,'opacity':opacity,'depth':depth}
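+# The weights used above are the standard volume-rendering quantities:
+#   alpha_i = 1 - exp(-sigma_i * intv),   T_i = exp(-sum_{j<i} sigma_j * intv),
+#   w_i = T_i * alpha_i,
+# and rgb / depth / opacity are the w_i-weighted sums of the samples along each ray.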
+
+
+def get_sat_ori(opt):
+    W,H  = opt.data.sat_size
+    y_range =  (torch.arange(H,dtype=torch.float32,)+0.5)/(0.5*H)-1
+    x_range = (torch.arange(W,dtype=torch.float32,)+0.5)/(0.5*H)-1
+    Y,X = torch.meshgrid(y_range,x_range)
+    Z = torch.ones_like(Y)
+    xy_grid = torch.stack([X,Z,Y],dim=-1)[None,:,:]
+    return xy_grid
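+# get_sat_ori places one ray origin per satellite pixel on the top plane of the volume
+# (second coordinate fixed to 1 in the normalized (w, z, h) space); render_sat below
+# then casts these rays straight down (direction [0, -1, 0]) through the density voxel.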
+
+def render_sat(opt,voxel):
+    '''
+    voxel: voxel has been processed
+    '''
+    # pano_W,pano_H = opt.data.pano_size
+    sat_W,sat_H = opt.data.sat_size
+    sat_ori  = get_sat_ori(opt)
+    sat_dir  = torch.tensor([0,-1,0])[None,None,None,:]
+
+    ##### get origin, sample point ,depth
+    if opt.data.dataset =='CVACT_Shi':
+        origin_height=2      
+        realworld_scale = 30  
+    elif opt.data.dataset == 'CVUSA':
+        origin_height=2       
+        realworld_scale = 55  
+
+    else:
+        raise NotImplementedError('dataset not implemented yet')
+
+    pixel_resolution = realworld_scale/sat_W #### pixel resolution of the satellite image in the real world
+    # if opt.data.sample_total_length:
+    #     sample_total_length = opt.data.sample_total_length
+    # else: sample_total_length = (int(max(np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(2)**2), \
+    #     np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(opt.data.max_height-origin_height)**2))/pixel_resolution))/(sat_W/2)
+    sample_total_length = 2
+    # #### sample_total_length: could be redefined in the future; it is the farthest distance between a sample point and the origin
+    # assert sat_W==sat_H
+
+    origin = sat_ori.to(opt.device)  ## (w,z,h), similar to the NeRF coordinate definition
+    sample_len = ((torch.arange(opt.data.sample_number)+1)*(sample_total_length/opt.data.sample_number)).to(opt.device)
+    ### sample_len: the sampling step is fixed, so the distances along a ray follow from the total length and the number of samples
+    origin = origin[...,None].to(opt.device)
+    direction = sat_dir[...,None].to(opt.device) ### the direction has been normalized
+    depth = sample_len[None,None,None,None,:]
+    sample_point = origin + direction * depth  # sample points in (w,z,h) order
+
+
+    N = voxel.size(1)
+    voxel_low = -1
+    voxel_max = -1 + opt.data.max_height/(realworld_scale/2)  ### top of the voxel volume in the normalized space
+    # axis_voxel = (torch.arange(N)/N) * (voxel_max-voxel_low) +voxel_low
+    grid = sample_point.permute(0,4,1,2,3)[...,[0,2,1]] ### BS,NUM_point,W,H,3
+    grid[...,2]   = ((grid[...,2]-voxel_low)/(voxel_max-voxel_low))*2-1  ### rescale the z axis so the grid lies in grid_sample's [-1,1] space
+    grid = grid.float()  ## [1, 300, 256, 256, 3]
+    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)
+
+    depth_sample = depth.permute(0,1,2,4,3).view(1,-1,opt.data.sample_number,1)
+    alpha_grid = alpha_grid.permute(0,3,4,2,1).view(opt.batch_size,-1,opt.data.sample_number)
+    # color_grid = torch.flip(color_grid,[2])
+    # alpha_grid = torch.flip(alpha_grid,[2])
+    intv = sample_total_length/opt.data.sample_number
+    output = composite_sat(opt,density_samples=alpha_grid,depth_samples=depth_sample,intv = intv)
+    return output['opacity'],output['depth']
+
+def composite_sat(opt,density_samples,depth_samples,intv):
+    sigma_delta = density_samples*intv # [B,HW,N]
+    alpha = 1-(-sigma_delta).exp_() # [B,HW,N]
+    T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)) .exp_() # [B,HW,N]
+    prob = (T*alpha)[...,None] # [B,HW,N,1]
+    depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]
+    opacity = prob.sum(dim=2) # [B,HW,1]
+    depth = depth.permute(0,2,1).view(depth.size(0),-1,opt.data.sat_size[1],opt.data.sat_size[0])
+    opacity = opacity.view(opacity.size(0),1,opt.data.sat_size[1],opt.data.sat_size[0])
+    # return rgb,depth,opacity,prob # [B,HW,K]
+    return {'opacity':opacity,'depth':depth}
+
+if __name__ == '__main__':
+    # test_demo
+    opt=edict()
+    opt.device = 'cuda'
+    opt.data = edict()
+    opt.data.pano_size = [512,256]
+    opt.data.sat_size = [256,256]
+    opt.data.dataset = 'CVACT_Shi'
+    opt.data.max_height = 20
+    opt.data.sample_number = 300
+    opt.arch = edict()
+    opt.arch.gen = edict()
+    opt.optim = edict()
+    opt.optim.ground_prior = False
+    opt.arch.gen.transform_mode = 'volum_rendering'
+    # opt.arch.gen.transform_mode = 'proj_like_radus'
+    opt.origin_H_W = None  # render() expects this attribute; None places the camera at the image center
+    BS = 1
+    opt.data.sample_total_length = 1
+    sat_name = './CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png'
+    a = Image.open(sat_name)
+    a = np.array(a).astype(np.float32)
+    a = torch.from_numpy(a)
+    a = a.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS,1,1,1)/255.
+
+
+    pano = sat_name.replace('satview_correct','streetview').replace('_satView_polish','_grdView')
+    pano = np.array(Image.open(pano)).astype(np.float32)
+    pano = torch.from_numpy(pano)
+    pano = pano.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS,1,1,1)/255.
+    voxel=torch.zeros([BS, 65, 256, 256]).to(opt.device)
+    pano_direction = torch.from_numpy(get_original_coord(opt)).unsqueeze(0).to(opt.device)
+
+    import time
+    star = time.time()
+    with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=False) as prof:
+        out = render(opt,a,voxel,pano_direction)
+        rgb,opacity = out['rgb'],out['opacity']  # render() returns a dict, not a tuple
+    print(prof.table())
+      
+    print(time.time()-star) 
+
+    torchvision.utils.save_image(torch.cat([rgb,pano],2), opt.arch.gen.transform_mode + '.png')
+    print( opt.arch.gen.transform_mode + '.png')
+    torchvision.utils.save_image(opacity, 'opa.png')
\ No newline at end of file
diff --git a/model/sample.py b/model/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ddfa80dbee7d1f54a76a9b39fa17c428bac4f81
--- /dev/null
+++ b/model/sample.py
@@ -0,0 +1,110 @@
+import cv2
+import numpy as np
+import torch
+import torchvision
+
+class Equirectangular():
+    """
+    Randomly sample a perspective view from a panorama image.
+    Adapted from https://github.com/fuenwang/Equirec2Perspec/blob/master/Equirec2Perspec.py
+    """
+    def __init__(self, width = 256, height = 256, FovX = 100, theta = [0, 0]):
+        """
+        width: output image's width
+        height: output image's height
+        FovX: perspective camera FOV on x-axis (degree)
+        theta: [min, max] range (in degrees) from which the view's pitch angle is sampled
+        """
+        self.theta = theta
+        self.width = width
+        self.height = height
+        self.type = type
+
+        #create x-axis coordinates and corresponding y-axis coordinates
+        x = np.arange(width)
+        y = np.arange(height)
+        x, y = np.meshgrid(x, y) 
+        
+        #create homogeneous coordinates
+        z = np.ones_like(x)
+        xyz = np.concatenate([x[..., None], y[..., None], z[..., None]], axis=-1)
+        
+        # pinhole intrinsic matrix: f = 0.5 * width / tan(FovX/2), principal point at the image center
+        f = 0.5 * width * 1 / np.tan(np.radians(FovX/2))
+        # cx = (width - 1) / 2.0
+        # cy = (height - 1) / 2.0
+        cx = (width) / 2.0
+        cy = (height) / 2.0        
+        K = np.array([
+                [f, 0, cx],
+                [0, f, cy],
+                [0, 0,  1],
+            ], np.float32)
+        K_inv = np.linalg.inv(K)
+        xyz = xyz @ K_inv.T
+        self.xyz = xyz  ### self.xyz holds the direction of each ray in camera space for the un-rotated camera
+
+
+
+    def __call__(self, img1): 
+        batch = img1.shape[0]
+        PHI, THETA = self.getRandomRotation(batch)
+        y_axis = np.array([0.0, 1.0, 0.0], np.float32)
+        x_axis = np.array([1.0, 0.0, 0.0], np.float32)
+        #rotation matrix
+        xy_grid = []
+        for i in range(batch):
+            R1, _ = cv2.Rodrigues(y_axis * np.radians(PHI[i]))
+            R2, _ = cv2.Rodrigues(np.dot(R1, x_axis) * np.radians(THETA[i]))
+            R = R2 @ R1
+            #rotate
+            xyz = self.xyz @ R.T  ### direction of each ray in camera space after the camera is rotated
+            norm = np.linalg.norm(xyz, axis=-1, keepdims=True)
+            xyz_norm = xyz / norm
+            
+            #transfer to image coordinates
+            xy = self.xyz2xy(xyz_norm)
+            device = img1.device
+            xy = torch.from_numpy(xy).to(device).unsqueeze(0)
+            xy_grid.append(xy)
+        xy = torch.cat(xy_grid,dim=0)
+
+        #resample
+        return xy
+
+    def xyz2xy(self, xyz_norm):
+        # unpack the normalized ray direction
+        x = xyz_norm[..., 0]
+        y = xyz_norm[..., 1]
+        z = xyz_norm[..., 2]
+
+        lon = np.arctan2(x, z)
+        lat = np.arcsin(y)
+        ### transfer to the lon and lat
+
+        X = lon / (np.pi)
+        Y = lat / (np.pi) * 2
+        xy = np.stack([X, Y], axis=-1)
+        xy = xy.astype(np.float32)
+        
+        return xy
+
+    def getRandomRotation(self,batch_size):
+        # phi = np.random.rand(batch_size) * 360 -180
+        phi = np.random.randint(-180,180,batch_size)
+        assert(self.theta[0]<self.theta[1])
+        theta = np.random.randint(self.theta[0],self.theta[1],batch_size)
+        # theta = np.random.rand(batch_size)*(self.theta[1]-self.theta[0])-self.theta[0]
+        return phi, theta
+
+
+if __name__=='__main__':
+    # test demo
+    e = Equirectangular(theta=[0., 40.],width = 64, height = 64,FovX = 100)
+    img = cv2.imread('dataset/CVACT/streetview/__-DFIFxvZBCn1873qkqXA_grdView.png')[:,:,::-1]/255.0
+    img = img.transpose(2, 0, 1).astype(np.float32)
+    
+    img = torch.from_numpy(img).unsqueeze(0).repeat(10, 1, 1, 1)
+    equ= e(img) 
+    # print(PHI, THETA)   
+    torchvision.utils.save_image(torch.nn.functional.grid_sample(img, equ.float(), align_corners = True)*0.99, 'test_30.png')
\ No newline at end of file
diff --git a/offline_train_test.py b/offline_train_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2a5d5fe43428ac282f62957185a6317697f2bbf
--- /dev/null
+++ b/offline_train_test.py
@@ -0,0 +1,57 @@
+import os,sys
+import importlib
+import options
+from utils import log
+import warnings
+warnings.filterwarnings("ignore")
+os.environ['WANDB_IGNORE_GLOBS'] = '*.pth'
+os.environ['WANDB_MODE'] = 'dryrun'
+
+def main():
+    log.process(os.getpid())
+    log.title("[{}] (PyTorch code for testing Sat2Density and debug".format(sys.argv[0]))
+    opt_cmd = options.parse_arguments(sys.argv[1:])
+    opt = options.set(opt_cmd=opt_cmd)
+    if opt.test_ckpt_path and opt.task not in ["test" , "val","vis_test",'test_speed','test_vid','test_sty','test_interpolation']:
+        opt.task = "test"
+    if opt.task in ["train" , "Train"]:
+        opt.isTrain = True
+    else:
+        opt.isTrain = False
+
+    opt.name = opt.yaml if opt.name is None else opt.name
+    mode = importlib.import_module("model.{}".format(opt.model))
+    m = mode.Model(opt)
+
+    m.load_dataset(opt)
+    m.build_networks(opt)
+    # train
+    if opt.task in ["train" , "Train"]:
+        m.setup_optimizer(opt)
+        m.train(opt)
+
+    # test or visualization
+    elif opt.task in ["test" , "val","vis_test"]:
+        m.test(opt)
+
+    # test speed
+    elif opt.task == 'test_speed':
+        m.test_speed(opt)
+    # inference video results
+    elif opt.task == 'test_vid':
+        m.test_vid(opt)
+    # test one image with different styles
+    elif opt.task == 'test_sty':
+        m.test_sty(opt)
+    # test style interpolation
+    elif opt.task == 'test_interpolation':
+        m.test_interpolation(opt)    
+    else:
+        raise Exception("Unknow task")
+
+
+
+        
+
+if __name__=="__main__":
+    main()
diff --git a/options.py b/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f14b6d0d6bb7d2000a5c2ad12148770c30f08085
--- /dev/null
+++ b/options.py
@@ -0,0 +1,119 @@
+import numpy as np
+import os
+import torch
+import random
+import string
+import yaml
+from easydict import EasyDict as edict
+
+import utils
+from utils import log
+
+def parse_arguments(args):
+    """
+    Parse arguments from command line.
+    Syntax: --key1.key2.key3=value --> value
+            --key1.key2.key3=      --> None
+            --key1.key2.key3       --> True
+            --key1.key2.key3!      --> False
+    """
+    opt_cmd = {}
+    for arg in args:
+        assert(arg.startswith("--"))
+        if "=" not in arg[2:]:
+            key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true")
+        else:
+            key_str,value = arg[2:].split("=")
+        keys_sub = key_str.split(".")
+        opt_sub = opt_cmd
+        for k in keys_sub[:-1]:
+            if k not in opt_sub: opt_sub[k] = {}
+            opt_sub = opt_sub[k]
+        assert keys_sub[-1] not in opt_sub,keys_sub[-1]
+        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
+    opt_cmd = edict(opt_cmd)
+    return opt_cmd
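+# Usage sketch (flags chosen only to illustrate the syntax above; values are arbitrary):
+#   parse_arguments(["--optim.lr_gen=0.0002", "--data.sky_mask", "--cpu!"])
+#   -> {'optim': {'lr_gen': 0.0002}, 'data': {'sky_mask': True}, 'cpu': False}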
+
+def set(opt_cmd={}):
+    log.info("setting configurations...")
+    # load config from yaml file
+    assert("yaml" in opt_cmd)
+    fname = "options/{}.yaml".format(opt_cmd.yaml)
+    opt_base = load_options(fname)
+    # override with command line arguments
+    opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True)
+    process_options(opt)
+    log.options(opt)
+    return opt
+
+def load_options(fname):
+    with open(fname) as file:
+        opt = edict(yaml.safe_load(file))
+    if "_parent_" in opt:
+        # load parent yaml file(s) as base options
+        parent_fnames = opt.pop("_parent_")
+        if type(parent_fnames) is str:
+            parent_fnames = [parent_fnames]
+        for parent_fname in parent_fnames:
+            opt_parent = load_options(parent_fname)
+            opt_parent = override_options(opt_parent,opt,key_stack=[])
+            opt = opt_parent
+    print("loading {}...".format(fname))
+    return opt
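+# Inheritance sketch: a child yaml (e.g. options/sat2density_cvact.yaml) declares
+# `_parent_: options/base.yaml`; the parent is loaded first and the child's keys are
+# then merged on top of it through override_options.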
+
+def override_options(opt,opt_over,key_stack=None,safe_check=False):
+    for key,value in opt_over.items():
+        print(key,value)
+        if isinstance(value,dict):
+            # parse child options (until leaf nodes are reached)
+            opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check)
+        else:
+            # ensure command line argument to override is also in yaml file
+            if safe_check and key not in opt:
+                add_new = None
+                while add_new not in ["y","n"]:
+                    key_str = ".".join(key_stack+[key])
+                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
+                if add_new=="n":
+                    print("safe exiting...")
+                    exit()
+            opt[key] = value
+    return opt
+
+def process_options(opt):
+    # set seed
+    if opt.seed is not None:
+        random.seed(opt.seed)
+        np.random.seed(opt.seed)
+        torch.manual_seed(opt.seed)
+        torch.cuda.manual_seed_all(opt.seed)
+    else:
+        # create random string as run ID
+        randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
+        opt.name = str(opt.name)+"_{}".format(randkey)
+    assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough
+    opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
+
+def save_options_file(opt,output_path):
+    opt_fname = "{}/options.yaml".format(output_path)
+    if os.path.isfile(opt_fname):
+        with open(opt_fname) as file:
+            opt_old = yaml.safe_load(file)
+        if opt!=opt_old:
+            # prompt if options are not identical
+            opt_new_fname = "{}/options_temp.yaml".format(output_path)
+            with open(opt_new_fname,"w") as file:
+                yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4)
+            print("existing options file found (different from current one)...")
+            os.system("diff {} {}".format(opt_fname,opt_new_fname))
+            os.system("rm {}".format(opt_new_fname))
+            override = None
+            while override not in ["y","n"]:
+                override = input("override? (y/n) ")
+            if override=="n":
+                print("safe exiting...")
+                exit()
+        else: print("existing options file found (identical)")
+    else: print("(creating new options file...)")
+    with open(opt_fname,"w") as file:
+        yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4)
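+
+# Usage sketch (illustrative; assumes the dotted --key.subkey=value convention
+# parsed by parse_arguments above):
+#
+#   python train.py --yaml=sat2density_cvact --optim.lr_gen=0.0002
+#
+# loads options/sat2density_cvact.yaml (and its _parent_ chain) and overrides
+# optim.lr_gen; keys absent from the YAML trigger the safe_check prompt.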
diff --git a/options/base.yaml b/options/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a7a95869c8e09d9afb0990f3cf8df5c362fb68e0
--- /dev/null
+++ b/options/base.yaml
@@ -0,0 +1,57 @@
+project: 0_test                                               
+name:                                                      
+model: shi                                                     
+yaml:                                                       
+seed: 6004                                                  
+task: Train 
+Group: DEBUG                                                
+gpu: 0                                                      
+cpu: false                                                  
+load:                                                       
+arch: {}                                                    
+test_ckpt_path:
+demo_img: __-DFIFxvZBCn1873qkqXA_grdView.png
+
+# for testing
+sty_img:
+
+sky_img:
+
+# only for illumination interpolation visualization
+sty_img1:
+sty_img2:
+sky_img1:
+sky_img2:
+
+data:                                                       
+    sky_mask:                                               
+    root:                                                   
+    dataset:                                                
+    num_workers: 24                                         
+    histo_mode:
+    sample_total_length:
+    train_sub:
+
+optim:
+    lr_gen: 0.0001                                   # learning rate (main)
+    lr_dis: 0.0004
+    gan_mode: hinge
+    beta1: 0    
+    ground_prior:
+    loss_weight:
+        GAN: 1
+        L1: 1
+    perceptual_loss:
+        mode: 'vgg19'
+        layers: ['relu_3_1', 'relu_4_1', 'relu_5_1']
+        weights: [0.125, 0.25, 1.0]
+    lr_policy:
+        iteration_mode: False
+        type: step
+        step_size: 13
+        gamma: 0.1
+batch_size: 16                                              # batch size
+resume: false                                               # not test
+fp16:                                                       # not test
+vis_dir: 'vis'
+max_epochs: 30
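+
+# Note: these are repository-wide defaults; dataset-specific configs
+# (e.g. options/sat2density_cvact.yaml) inherit them via `_parent_` and
+# override individual fields.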
diff --git a/options/sat2density_cvact.yaml b/options/sat2density_cvact.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87c7de7ca0e77a0b21e758fd48bf87e04f9adb4d
--- /dev/null
+++ b/options/sat2density_cvact.yaml
@@ -0,0 +1,91 @@
+_parent_: options/base.yaml
+
+gpu_ids: '0'
+
+## config for wandb
+project: 'sat2pano'
+Group: 'craft_feature'
+
+
+model: craft_feature                           # model/craft_feature
+arch:   
+  gen:                                         ## config for generator
+    netG: imaginaire.generators.craft_2stage_add_style  
+    weight_norm_type: spectral
+    activation_norm_type: instance
+    padding_mode: reflect
+    transform_mode: volum_rendering
+    feature_model:
+    style_inject: histo                        # use histogram to inject illumination; choose from [histo, perspective]
+    cat_PE: 
+    cat_opa: true
+    cat_depth: true
+    depth_arch:                                # Density Net
+      name: depth
+      num_filters: 32                          
+      num_downsamples: 4                       
+      num_res_blocks: 6                        
+      output_nc: 64
+    render_arch:                               # Render Net
+      name: render
+      num_filters: 64                          
+      num_downsamples: 4                       
+      num_res_blocks: 9                        
+      output_nc: 3
+    style_enc_cfg:                              # style injection
+      input_image_channels: 3
+      num_filters: 256
+      kernel_size: 3
+      style_dims: 128
+      interm_style_dims: 256
+      hidden_channel: 256
+      weight_norm_type: spectral
+  dis:                                          # discriminator
+    netD: imaginaire.discriminators.multires_patch_pano
+    num_filters: 64
+    max_num_filters: 512
+    num_discriminators: 3
+    num_layers: 3
+    weight_norm_type: spectral
+    activation_norm_type: instance
+
+
+
+data:                                            # data options
+  dataset: CVACT_Shi                           # dataset name
+  root: ./dataset/CVACT/ 
+  sat_size: [256,256]              
+  pano_size: [512, 128]
+  sample_number: 100                           # points per ray
+  max_height: 8                                # pre-defined density space in height axis
+  sky_mask: true                               
+  histo_mode: rgb
+  # val: 
+  #   sub: 500        
+
+
+optim:   
+  lr_gen: 0.00005                                
+  lr_dis: 0.00005
+  gan_mode: non_saturated                         #'hinge', 'least_square',  'non_saturated', 'wasserstein'
+  loss_weight:
+    L1: 1   
+    L2: 10
+    GaussianKL: 0.1
+    feature_matching: 10.0
+    Perceptual: 10
+    sky_inner: 1 
+    GAN: 1                                              
+
+  lr_policy:
+    iteration_mode: False                         # iteration or epoch
+    type: step
+    step_size: 45                         
+    gamma: 0.1
+
+  ground_prior: true 
+
+######## for testing: if only_style is set, one style is chosen at random for the save dir
+only_style: 
+only_img:
+save_dir: 
\ No newline at end of file
diff --git a/options/sat2density_cvusa.yaml b/options/sat2density_cvusa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..398aecc4a2d73763933cc7eb2c36ff3d84b396cb
--- /dev/null
+++ b/options/sat2density_cvusa.yaml
@@ -0,0 +1,10 @@
+_parent_: options/sat2density_cvact.yaml
+
+
+arch:
+  dis:
+    num_discriminators: 2
+
+data:
+    dataset: CVUSA                                       
+    root: ./dataset/CVUSA/ 
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3db762bf06fb5e9eb3016f8649fe0306bb84e599
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+cmake
+numpy==1.22.1
+Pillow>=8.3.2
+scipy
+scikit-image
+tqdm==4.35.0
+cython
+qimage2ndarray
+requests==2.25.1
+tensorboard
+lpips
+easydict
+termcolor
+wandb==0.13.3
+pytorch_msssim
+opencv-contrib-python==4.6.0.66
+albumentations
+pyvista
\ No newline at end of file
diff --git a/scripts/INSTALL.md b/scripts/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..7f2f095ebf5e18fb6d369c7e7b2a8285dd854cfa
--- /dev/null
+++ b/scripts/INSTALL.md
@@ -0,0 +1,11 @@
+1. Install Anaconda.
+2. Install CUDA 11.1 and cuDNN (required by [imaginaire](https://github.com/NVlabs/imaginaire)).
+   If you are not a root user, you can install CUDA under your '~' directory.
+3. Make sure `nvcc -V` and `cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2`
+   produce the expected output.
+4. `apt-get update && apt-get install cmake` (outside a docker container, prefix every apt-get with `sudo`).
+5. `conda activate your-env-name` (make sure python>3.8).
+6. `pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu111`,
+   or use the wheels at https://mirror.sjtu.edu.cn/pytorch-wheels/cu111/torch_stable.html
+7. Check that `torch.cuda.is_available()` returns True.
+8. `bash scripts/install.sh`
\ No newline at end of file
diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a6adb8ed2a92efb3931f8e57cc49f5b5974823a2
--- /dev/null
+++ b/scripts/build_docker.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+key=${1}
+
+rm -rf Dockerfile
+echo "FROM nvcr.io/nvidia/pytorch:${key}-py3" > Dockerfile
+input="Dockerfile.base"
+
+while IFS= read -r line
+do
+  echo "$line" >> Dockerfile
+done < "$input"
+
+input="scripts/requirements.txt"
+while IFS= read -r line
+do
+  echo "RUN pip install $line" >> Dockerfile
+done < "$input"
+
+
+for p in correlation channelnorm resample2d bias_act upfirdn2d; do
+  echo "COPY imaginaire/third_party/$p $p" >> Dockerfile
+  echo "RUN cd $p && rm -rf build dist *-info && python setup.py install" >> Dockerfile
+done
+
+# Compile GANcraft libraries.
+echo "COPY imaginaire/model_utils/gancraft/voxlib gancraft/voxlib" >> Dockerfile
+echo "RUN cd gancraft/voxlib && make" >> Dockerfile
+
+docker build -t nvcr.io/nvidian/lpr-imagine/imaginaire:${key}-py3 .
diff --git a/scripts/build_index.py b/scripts/build_index.py
new file mode 100644
index 0000000000000000000000000000000000000000..427701dc963c4d6afbd5e5a1301de54da1687889
--- /dev/null
+++ b/scripts/build_index.py
@@ -0,0 +1,67 @@
+import argparse
+import json
+import os
+import sys
+
+sys.path.append('.')
+from imaginaire.utils.lmdb import create_metadata  # noqa: E402
+from imaginaire.config import Config  # noqa: E402
+
+
+def parse_args():
+    r"""Parse user input arguments"""
+    parser = argparse.ArgumentParser(description='Folder -> LMDB conversion')
+    parser.add_argument('--data_root', type=str, required=True,
+                        help='Input data location.')
+    parser.add_argument('--output_root', type=str, default='',
+                        help='Input data location.')
+    parser.add_argument('--config', type=str, required=True,
+                        help='Config with label info.')
+    parser.add_argument('--paired', default=False, action='store_true',
+                        help='Is the input data paired?')
+    parser.add_argument('--input_list', type=str, default='',
+                        help='list of images that will be used.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    r""" Build lmdb for training/testing.
+    Usage:
+    python scripts/build_index.py \
+      --config configs/data_image.yaml \
+      --data_root /mnt/bigdata01/datasets/test_image \
+      --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/
+    """
+    args = parse_args()
+    if args.output_root == '':
+        args.output_root = args.data_root
+    cfg = Config(args.config)
+
+    all_filenames, extensions = \
+        create_metadata(
+            data_root=args.data_root,
+            cfg=cfg,
+            paired=args.paired,
+            input_list=args.input_list)
+
+    os.makedirs(args.output_root, exist_ok=True)
+
+    if args.paired:
+        base = args.data_root.split('/')[-1]
+        new_all_filenames = dict()
+        for key in all_filenames.keys():
+            new_all_filenames['{}/{}'.format(base, key)] = all_filenames[key]
+        all_filenames = new_all_filenames.copy()
+
+    # Output list of all filenames.
+    with open(args.output_root + '/all_filenames.json', 'w') as fout:
+        json.dump(all_filenames, fout, indent=4)
+
+    # Output metadata.
+    with open(args.output_root + '/metadata.json', 'w') as fout:
+        json.dump(extensions, fout, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/build_lmdb.py b/scripts/build_lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..544afaa7b2b77f78f64c9e82ad60278437b00d21
--- /dev/null
+++ b/scripts/build_lmdb.py
@@ -0,0 +1,125 @@
+import copy
+import shutil
+import argparse
+import json
+import sys
+import os
+from tqdm import tqdm
+
+sys.path.append('.')
+from imaginaire.utils.lmdb import create_metadata, \
+    construct_file_path, check_and_add, build_lmdb  # noqa: E402
+from imaginaire.config import Config  # noqa: E402
+
+
+def parse_args():
+    r"""Parse user input arguments"""
+    parser = argparse.ArgumentParser(description='Folder -> LMDB conversion')
+    parser.add_argument('--data_root', type=str, required=True,
+                        help='Input data location.')
+    parser.add_argument('--config', type=str, required=True,
+                        help='Config with label info.')
+    parser.add_argument('--output_root', type=str, required=True,
+                        help='Output LMDB location')
+    parser.add_argument('--input_list', type=str, default='',
+                        help='list of images that will be used.')
+    parser.add_argument('--metadata_factor', type=float, default=0.75,
+                        help='Factor of filesize to allocate for metadata?')
+    parser.add_argument('--overwrite', default=False, action='store_true',
+                        help='Overwrite output file if exists')
+    parser.add_argument('--paired', default=False, action='store_true',
+                        help='Is the input data paired?')
+    parser.add_argument('--large', default=False, action='store_true',
+                        help='Is the dataset large?')
+    parser.add_argument('--remove_missing', default=False, action='store_true',
+                        help='Remove missing files from paired datasets?')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    r""" Build lmdb for training/testing.
+    Usage:
+    python scripts/build_lmdb.py \
+      --config configs/data_image.yaml \
+      --data_root /mnt/bigdata01/datasets/test_image \
+      --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \
+      --overwrite
+    """
+    args = parse_args()
+    cfg = Config(args.config)
+
+    # Check if output file already exists.
+    if os.path.exists(args.output_root):
+        if args.overwrite:
+            print('Deleting existing output LMDB.')
+            shutil.rmtree(args.output_root)
+        else:
+            print('Output root LMDB already exists. Use --overwrite. ' +
+                  'Exiting...')
+            return
+
+    all_filenames, extensions = \
+        create_metadata(data_root=args.data_root,
+                        cfg=cfg,
+                        paired=args.paired,
+                        input_list=args.input_list)
+    required_data_types = cfg.data.data_types
+
+    # Build LMDB.
+    os.makedirs(args.output_root)
+    for data_type in required_data_types:
+        data_size = 0
+        print('Data type:', data_type)
+        filepaths, keys = [], []
+        print('>> Building file list.')
+
+        # Get appropriate list of files.
+        if args.paired:
+            filenames = all_filenames
+        else:
+            filenames = all_filenames[data_type]
+
+        for sequence in tqdm(filenames):
+            for filename in copy.deepcopy(filenames[sequence]):
+                filepath = construct_file_path(
+                    args.data_root, data_type, sequence, filename,
+                    extensions[data_type])
+                key = '%s/%s' % (sequence, filename)
+                filesize = check_and_add(filepath, key, filepaths, keys,
+                                         remove_missing=args.remove_missing)
+
+                # Remove file from list, if missing.
+                if filesize == -1 and args.paired and args.remove_missing:
+                    print('Removing %s from list' % (filename))
+                    filenames[sequence].remove(filename)
+                data_size += filesize
+
+        # Remove empty sequences.
+        if args.paired and args.remove_missing:
+            for sequence in copy.deepcopy(all_filenames):
+                if not all_filenames[sequence]:
+                    all_filenames.pop(sequence)
+
+        # Allocate size.
+        data_size = max(int((1 + args.metadata_factor) * data_size), 1e9)
+        print('Reserved size: %s, %dGB' % (data_type, data_size // 1e9))
+
+        # Write LMDB to file.
+        output_filepath = os.path.join(args.output_root, data_type)
+        build_lmdb(filepaths, keys, output_filepath, data_size, args.large)
+
+    # Output list of all filenames.
+    if args.output_root:
+        with open(args.output_root + '/all_filenames.json', 'w') as fout:
+            json.dump(all_filenames, fout, indent=4)
+
+        # Output metadata.
+        with open(args.output_root + '/metadata.json', 'w') as fout:
+            json.dump(extensions, fout, indent=4)
+    else:
+        return all_filenames, extensions
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/build_lmdb.sh b/scripts/build_lmdb.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a749cd18124634cb2de57433f1b9d16ca291919c
--- /dev/null
+++ b/scripts/build_lmdb.sh
@@ -0,0 +1,9 @@
+MODEL=$1
+DATASET=$2
+
+for SPLIT in test train; do
+  RAW=dataset/${DATASET}_raw/${SPLIT}
+  LMDB=dataset/${DATASET}/${SPLIT}
+  echo ${LMDB}
+  python scripts/build_lmdb.py --config configs/projects/${MODEL}/${DATASET}/ampO1.yaml --data_root ${RAW} --output_root ${LMDB} --overwrite
+done
\ No newline at end of file
diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..156910fca512b2e789236e2e1fe94bcf5ff90473
--- /dev/null
+++ b/scripts/download_dataset.py
@@ -0,0 +1,49 @@
+import argparse
+import os
+import tarfile
+import sys
+
+sys.path.append('.')
+from imaginaire.utils.io import download_file_from_google_drive  # noqa: E402
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Download and process dataset')
+    parser.add_argument('--dataset', help='Name of the dataset.', required=True,
+                        choices=['afhq_dog2cat',
+                                 'animal_faces'])
+    parser.add_argument('--data_dir', default='./dataset',
+                        help='Directory to save all datasets.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    if args.dataset == 'afhq_dog2cat':
+        url = '1XaiwS0eRctqm-JEDezOBy4TXriAQgc4_'
+    elif args.dataset == 'animal_faces':
+        url = '1ftr1xWm0VakGlLUWi7-hdAt9W37luQOA'
+    else:
+        raise ValueError('Invalid dataset {}.'.format(args.dataset))
+
+    # Create the dataset directory.
+    if not os.path.exists(args.data_dir):
+        os.makedirs(args.data_dir)
+
+    # Download the compressed dataset.
+    folder_path = os.path.join(args.data_dir, args.dataset + '_raw')
+    compressed_path = folder_path + '.tar.gz'
+    if not os.path.exists(compressed_path) and not os.path.exists(folder_path):
+        print("Downloading the dataset {}.".format(args.dataset))
+        download_file_from_google_drive(url, compressed_path)
+
+    # Extract the dataset.
+    if not os.path.exists(folder_path):
+        print("Extracting the dataset {}.".format(args.dataset))
+        with tarfile.open(compressed_path) as tar:
+            tar.extractall(folder_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/download_test_data.py b/scripts/download_test_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8cede6105c60509ed38fc7d7b473a5c782044a3
--- /dev/null
+++ b/scripts/download_test_data.py
@@ -0,0 +1,60 @@
+import argparse
+import os
+import sys
+import tarfile
+sys.path.append('.')
+from imaginaire.utils.io import download_file_from_google_drive  # noqa: E402
+
+URLS = {
+    'pix2pixhd': '1Xg9m184zkuG8H0LHdBtSzt2VbMi3SWwR',
+    'spade': '1ESm-gHWu_aMHnKF42qkGc8qf1SBECsgf',
+    'funit': '1a-EE_6RsYPUoKxEl5oXrpRmKYUltqaD-',
+    'coco_funit': '1JYVYB0Q1VStDLOb0SBJbN1vkaf6KrGDh',
+    'unit': '17BbwnCG7qF7FI-t9VkORv2XCKqlrY1CO',
+    'munit': '1VPgHGuQfmm1N1Vh56wr34wtAwaXzjXtH',
+    'vid2vid': '1SHvGPMq-55GDUQ0Ac2Ng0eyG5xCPeKhc',
+    'fs_vid2vid': '1fTj0HHjzcitgsSeG5O_aWMF8yvCQUQkN',
+    'wc_vid2vid/cityscapes': '1KKzrTHfbpBY9xtLqK8e3QvX8psSdrFcD',
+    'wc_vid2vid/mannequin': '1mafZf9KJrwUGGI1kBTvwgehHSqP5iaA0',
+    'gancraft': '1m6q7ZtYJjxFL0SQ_WzMbvoLZxXmI5_vJ',
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Download test data.')
+    parser.add_argument('--model_name', required=True,
+                        help='Name of the model.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    test_data_dir = 'projects/' + args.model_name + '/test_data'
+    print(test_data_dir)
+    assert args.model_name in URLS, 'No sample test data available'
+    url = URLS[args.model_name]
+
+    if os.path.exists(test_data_dir):
+        print('Test data exists at', test_data_dir)
+        compressed_path = test_data_dir + '.tar.gz'
+        # Extract the dataset.
+        print('Extracting test data to', test_data_dir)
+        with tarfile.open(compressed_path) as tar:
+            tar.extractall(path=test_data_dir)
+    else:
+        os.makedirs(test_data_dir, exist_ok=True)
+        # Download the compressed dataset.
+        compressed_path = test_data_dir + '.tar.gz'
+        if not os.path.exists(compressed_path):
+            print('Downloading test data to', compressed_path)
+            download_file_from_google_drive(url, compressed_path)
+
+        # Extract the dataset.
+        print('Extracting test data to', test_data_dir)
+        with tarfile.open(compressed_path) as tar:
+            tar.extractall(path=test_data_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/download_weights.sh b/scripts/download_weights.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7cac24c86c499d0aede03adf46c3d246ce9eb008
--- /dev/null
+++ b/scripts/download_weights.sh
@@ -0,0 +1,16 @@
+CHECKPOINTS="run-20230219_141512-2u87bj8w.zip"
+
+if [ ! -d "wandb" ]; then
+  mkdir wandb
+fi
+
+for checkpoint in $CHECKPOINTS ; do
+    echo "Downloading $checkpoint";
+    if [ ! -f "wandb/$checkpoint" ]; then
+        wget https://github.com/sat2density/checkpoints/releases/download/cvusa/$checkpoint -P wandb
+    fi
+    echo "Unzipping $checkpoint";
+    if [ ! -d "wandb/${checkpoint%.*}" ]; then
+        unzip wandb/$checkpoint -d wandb
+    fi
+done
\ No newline at end of file
diff --git a/scripts/install.bat b/scripts/install.bat
new file mode 100644
index 0000000000000000000000000000000000000000..c5e40f6fdf6ea7ddf8beb84574b06846b7308bea
--- /dev/null
+++ b/scripts/install.bat
@@ -0,0 +1,25 @@
+@ECHO OFF
+FOR /F "tokens=*" %%g IN ('nvcc --version') do (set ver=%%g)
+
+echo %ver%
+set CUDA_VERSION=%ver:~11,4%
+echo %CUDA_VERSION%
+
+pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+pip install --upgrade -r scripts/requirements.txt
+
+echo %cd%
+set curr_directory=%cd%
+echo %curr_directory%
+
+for %%p in (correlation channelnorm resample2d bias_act upfirdn2d) do (
+  cd %curr_directory%
+  cd imaginaire\third_party\%%p
+  rmdir /s /q build dist *info
+  python setup.py install
+  cd %curr_directory%
+)
+
+
+
diff --git a/scripts/install.sh b/scripts/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b6fe8e261838c64f60d6911ecaa9a76821a7448a
--- /dev/null
+++ b/scripts/install.sh
@@ -0,0 +1,37 @@
+#!/bin/sh
+CURRENT=$(pwd)
+
+# Check CUDA_VERSION
+export CUDA_VERSION=$(nvcc --version| grep -Po "(\d+\.)+\d+" | head -1)
+
+apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
+        build-essential \
+        git \
+        curl \
+        vim \
+        tmux \
+        wget \
+        bzip2 \
+        unzip \
+        g++ \
+        ca-certificates \
+        ffmpeg \
+        libx264-dev \
+        imagemagick
+
+pip install --upgrade -r scripts/requirements.txt
+# pip install -U git+https://github.com/pyvista/pyvista.git@main
+
+
+for p in correlation channelnorm resample2d bias_act upfirdn2d; do
+  cd imaginaire/third_party/${p};
+  rm -rf build dist *info;
+  python setup.py install;
+  cd ${CURRENT};
+done
+
+# for p in gancraft/voxlib; do
+#   cd imaginaire/model_utils/${p};
+#   make all
+#   cd ${CURRENT};
+# done
diff --git a/scripts/start_local_docker.sh b/scripts/start_local_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..78dca4657ef90b26726aea5b9e20e79213128e2c
--- /dev/null
+++ b/scripts/start_local_docker.sh
@@ -0,0 +1,10 @@
+docker run \
+    --gpus all \
+    --shm-size 32g \
+    --ipc=host \
+    -it \
+    -v /mnt:/mnt \
+    -v ~/:/home \
+    nvcr.io/nvidian/lpr-imagine/imaginaire:${1}-py3 \
+    /bin/bash
+
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..965a1c5ff5f3abd9809f164e4afff43eac64f110
--- /dev/null
+++ b/test.py
@@ -0,0 +1,350 @@
+import importlib
+import os
+import os.path as osp
+import sys
+import warnings
+
+import torch
+
+import options
+from utils import log
+
+warnings.filterwarnings("ignore")
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torchvision.transforms as transforms
+from matplotlib.widgets import Cursor
+from PIL import Image
+from scipy.interpolate import interp1d, splev, splprep
+from torch.utils.data import default_convert,default_collate
+import torchvision
+
+from model.geometry_transform import render_sat,render
+import cv2 
+import imageio 
+
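+# Map the short wandb run IDs of the released checkpoints to their local
+# paths; any other --test_ckpt_path value is used as a checkpoint path as-is.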
+def get_checkpoint(opt):
+    if opt.test_ckpt_path == '2u87bj8w':
+        opt.test_ckpt_path = osp.join('wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth')
+    elif opt.test_ckpt_path == '2cqv8uh4':
+        opt.test_ckpt_path = osp.join('wandb/run-20230303_142752-2cqv8uh4/files/checkpoint/model.pth')
+    else:
+        pass
+
+
+def img_read(img,size=None,datatype='RGB'):
+    img = Image.open(img).convert('RGB' if datatype=='RGB' else "L")
+    if size:
+        if type(size) is int:
+            size = (size,size)
+        img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST)
+    img = transforms.ToTensor()(img)
+    return img
+
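+# Let the user drag a trajectory on the satellite image with the mouse;
+# returns the de-duplicated clicked pixels, the per-step heading angles, and a
+# spline-smoothed version of the path.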
+def select_points(sat_image):
+    fig = plt.figure()
+    fig.set_size_inches(1,1,forward=False)
+    ax = plt.Axes(fig, [0., 0., 1., 1.])
+    ax.set_axis_off()
+    ax.imshow(sat_image)
+
+    coords = []
+
+    def ondrag(event):
+        if event.button != 1:
+            return
+        x, y = int(event.xdata), int(event.ydata)
+        coords.append((x, y))
+        ax.plot([x], [y], 'o', color='red')
+        fig.canvas.draw_idle()
+        
+    fig.add_axes(ax)
+    cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
+    fig.canvas.mpl_connect('motion_notify_event', ondrag)
+    plt.show()
+    plt.close()
+
+    # drop duplicate clicks while preserving order
+    pixels = list(dict.fromkeys(coords))
+    print(pixels)
+    pixels = np.array(pixels)
+    tck, u = splprep(pixels.T, s=25, per=0)
+    u_new = np.linspace(u.min(), u.max(), 80)
+    x_new, y_new = splev(u_new, tck)
+
+    smooth_path = np.array([x_new,y_new]).T
+    
+    angles = np.arctan2(y_new[1:]-y_new[:-1],x_new[1:]-x_new[:-1])
+    
+    return pixels, angles, smooth_path
+
+def volume2pyvista(volume_data):
+    import pyvista as pv 
+    grid = pv.UniformGrid()
+    grid.dimensions = volume_data.shape
+    grid.spacing = (1, 1, 1)
+    grid.origin = (0, 0, 0)
+    grid.point_data['values'] = volume_data.flatten(order='F')
+    return grid
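+
+# Illustrative only: the .vtk volume exported in test_vid below can be
+# inspected with pyvista, e.g.
+#   import pyvista as pv
+#   pv.read('<save_dir>/volume.vtk').plot(volume=True)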
+
+
+def img_pair2vid(sat_list, save_dir, media_path='interpolation.mp4'):
+    # pack the rendered 512x128 frames into an mp4 at 12 fps
+    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
+    out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 128))
+    for i in range(len(sat_list)):
+        img1 = cv2.imread(os.path.join(save_dir, sat_list[i]))
+        out.write(img1)
+    out.release()
+
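+# test_vid: load a trained generator, let the user draw a camera path on the
+# satellite image (see select_points), render a ground-view image and depth map
+# at each selected point, and export the predicted density volume to volume.vtk
+# along with per-frame satellite visualizations.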
+@torch.no_grad()
+def test_vid(model, opt):
+    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+    model.netG.load_state_dict(ckpt['netG'])
+    model.netG.eval()
+    
+    # for idx, data in enumerate(model.val_loader):
+    #     import pdb; pdb.set_trace()
+    demo_imgpath = opt.demo_img 
+    sty_imgpath = opt.sty_img 
+    if opt.sky_img is None:
+        sky_imgpath = opt.sty_img.replace('image','sky')
+    else:
+        sky_imgpath = opt.sky_img
+
+    sat = img_read(demo_imgpath, size=opt.data.sat_size)
+    pano = img_read(sty_imgpath, size=opt.data.pano_size)
+
+    input_dict = {}
+    input_dict['sat'] = sat
+    input_dict['pano'] = pano
+    input_dict['paths'] = demo_imgpath
+
+
+    if opt.data.sky_mask:
+        sky = img_read(sky_imgpath, size=opt.data.pano_size, datatype='L') 
+        input_a = pano*sky
+        sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+        input_dict['sky_histc'] = sky_histc
+        input_dict['sky_mask'] = sky
+    else:
+        sky_histc = None
+    
+    for key in input_dict.keys():
+        if isinstance(input_dict[key], torch.Tensor):
+            input_dict[key] = input_dict[key].unsqueeze(0)
+
+    model.set_input(input_dict)
+    
+    model.style_temp = model.sky_histc
+    
+    pixels, angles, smooth_path = select_points(sat_image=sat.permute(1,2,0).numpy())
+
+    rendered_image_list = []
+    rendered_depth_list = []
+    
+
+    volume_data = None
+
+    for i, (x,y) in enumerate(pixels):
+        opt.origin_H_W = [(y-128)/128, (x-128)/128] # TODO: hard-coded 256x256 satellite size; remove in the future
+        print('Rendering at ({}, {})'.format(x,y))
+        model.forward(opt)
+
+        rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().numpy().transpose((1,2,0))
+        rgb = np.array(rgb*255, dtype=np.uint8)
+        rendered_image_list.append(rgb)
+
+        rendered_depth_list.append(
+            model.out_put.depth[0,0].cpu().numpy()
+        )
+        
+
+    sat_opacity, sat_depth = render_sat(opt,model.out_put.voxel)
+    
+    volume_data = model.out_put.voxel[0].cpu().numpy().transpose((1,2,0))
+    volume_data = np.clip(volume_data, None, 10)
+    
+    volume_export = volume2pyvista(volume_data)
+
+    os.makedirs(opt.save_dir, exist_ok=True)
+    volume_export.save(os.path.join(opt.save_dir, 'volume.vtk'))
+
+    # save rendered images 
+    os.makedirs(osp.join(opt.save_dir,'rendered_images'), exist_ok=True)
+
+    for i, img in enumerate(rendered_image_list):
+        plt.imsave(osp.join(opt.save_dir,'rendered_images','{:05d}.png'.format(i)), img)
+
+    os.makedirs(osp.join(opt.save_dir,'rendered_depth'), exist_ok=True)
+
+    os.makedirs(osp.join(opt.save_dir,
+    'rendered_images+depths'), exist_ok=True)
+
+    for i, img in enumerate(rendered_depth_list):
+        depth = np.array(img/img.max()*255,dtype=np.uint8)
+        depth = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
+        plt.imsave(osp.join(opt.save_dir,'rendered_depth','{:05d}.png'.format(i)), depth)
+        image_and_depth = np.concatenate((rendered_image_list[i], depth), axis=0)
+
+        plt.imsave(osp.join(opt.save_dir,'rendered_images+depths','{:05d}.png'.format(i)), image_and_depth)
+    
+    os.makedirs(osp.join(opt.save_dir,'sat_images'), exist_ok=True)
+    
+    for i, (x,y) in enumerate(pixels):
+        
+        
+        # plt.plot(x, y, 'o', color='red')
+
+        sat_rgb = sat.permute(1,2,0).numpy()
+        sat_rgb = np.array(sat_rgb*255, dtype=np.uint8)
+        fig = plt.figure()
+        fig.set_size_inches(1,1,forward=False)
+        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        ax.set_axis_off()
+        ax.imshow(sat_rgb)
+        ax.plot(pixels[:i+1,0], pixels[:i+1,1], 'r-', color='red')
+        ax.plot(x, y, 'o', color='red', markersize=2)
+        # if i < len(pixels)-1:
+        # #     ax.plot([x,pixels[0,0]],[y,pixels[0,1]],'r-')
+        # # else:
+        #     ax.plot([x,pixels[i+1,0]],[y,pixels[i+1,1]],'r-')
+        fig.add_axes(ax)
+        plt.savefig(osp.join(opt.save_dir,'sat_images','{:05d}.png'.format(i)),bbox_inches='tight', pad_inches=0, dpi=256)
+        
+    print('Done')
+
+
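+# test_interpolation: keep the satellite-view geometry fixed while linearly
+# interpolating between the style codes of two ground-view sky histograms
+# (sty_img1 -> sty_img2); frames are written to save_dir and packed into
+# interpolation.mp4.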
+@torch.no_grad()
+def test_interpolation(model,opt):
+    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+    model.netG.load_state_dict(ckpt['netG'])
+    model.netG.eval()
+
+
+
+
+    sat = img_read(opt.demo_img , size=opt.data.sat_size)
+    pano1 = img_read(opt.sty_img1 , size=opt.data.pano_size)
+    pano2 = img_read(opt.sty_img2 , size=opt.data.pano_size)
+    
+
+    input_dict = {}
+    input_dict['sat'] = sat
+    input_dict['paths'] = opt.demo_img 
+
+    # black_ground = torch.zeros_like(pano1)
+    sky_imgpath1 = opt.sty_img1.replace('image','sky')
+    sky_imgpath2 = opt.sty_img2.replace('image','sky')
+
+    sky = img_read(sky_imgpath1, size=opt.data.pano_size, datatype='L') 
+    input_a = pano1*sky
+    sky_histc1 = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+
+    # for idx in range(len(input_a)):
+    #     if idx == 0:
+    #         sky_histc1 = input_a[idx].histc()[10:]
+    #     else:
+    #         sky_histc1 = torch.cat([input_a[idx].histc()[10:],sky_histc1],dim=0)
+
+    sky = img_read(sky_imgpath2, size=opt.data.pano_size, datatype='L') 
+    input_b = pano2*sky
+    sky_histc2 = torch.cat([input_b[i].histc()[10:] for i in reversed(range(3))])
+    # for idx in range(len(input_b)):
+    #     if idx == 0:
+    #         sky_histc2 = input_b[idx].histc()[10:]
+    #     else:
+    #         sky_histc2 = torch.cat([input_b[idx].histc()[10:],sky_histc2],dim=0)
+
+    for key in input_dict.keys():
+        if isinstance(input_dict[key], torch.Tensor):
+            input_dict[key] = input_dict[key].unsqueeze(0)
+
+    model.set_input(input_dict)
+    pixels = [(128,128)]
+    
+    x,y =  pixels[0]
+    opt.origin_H_W = [(y-128)/128 , (x-128)/128]
+    print(opt.origin_H_W)
+
+    estimated_height = model.netG.depth_model(model.real_A)
+    geo_outputs = render(opt,model.real_A,estimated_height,model.netG.pano_direction,PE=model.netG.PE)
+    generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+    if model.netG.gen_cfg.cat_opa:
+        generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+    if model.netG.gen_cfg.cat_depth:
+        generator_inputs = torch.cat((generator_inputs,depth),dim=1)
+    _, _, z1 = model.netG.style_encode(sky_histc1.unsqueeze(0).to(model.device))
+    _, _, z2 = model.netG.style_encode(sky_histc2.unsqueeze(0).to(model.device))
+    num_inter = 60
+    for i in range(num_inter):
+        z = z1 * (1-i/(num_inter-1)) + z2* (i/(num_inter-1))
+        z = model.netG.style_model(z)
+        output_RGB = model.netG.denoise_model(generator_inputs,z)
+
+        save_img = output_RGB.cpu()
+        name = 'img{:03d}.png'.format(i)
+        torchvision.utils.save_image(save_img,os.path.join(opt.save_dir,name))
+
+    sat_list = sorted(os.listdir(opt.save_dir))
+    media_path = os.path.join(opt.save_dir,'interpolation.mp4')
+
+    img_pair2vid(sat_list,opt.save_dir,media_path)
+    print('Done, saved to', media_path)
+
+def main():
+    log.process(os.getpid())
+    log.title("[{}] (PyTorch code for testing Sat2Density and debug".format(sys.argv[0]))
+    
+    opt_cmd = options.parse_arguments(sys.argv[1:])
+    opt = options.set(opt_cmd=opt_cmd)
+    opt.isTrain = False
+    opt.name = opt.yaml if opt.name is None else opt.name
+    opt.batch_size = 1
+
+    if opt.save_dir is None:
+        raise Exception("Please specify the save dir")
+
+    get_checkpoint(opt)
+
+    mode = importlib.import_module("model.{}".format(opt.model))
+    m = mode.Model(opt)
+
+    # m.load_dataset(opt)
+    m.build_networks(opt)
+
+    if os.path.exists(opt.save_dir):
+        import shutil
+        shutil.rmtree(opt.save_dir)
+    if opt.task == 'test_vid':
+        test_vid(m, opt)
+    if opt.task == 'test_interpolation':
+        assert opt.sty_img1
+        assert opt.sty_img2
+        os.makedirs(opt.save_dir, exist_ok=True)
+        test_interpolation(m,opt)
+    
+    # import pdb; pdb.set_trace()
+    
+    # print(m)
+    # # test or visualization
+    # if opt.task == 'test_vid':
+    #     m.test_vid(opt)
+    # elif opt.task == 'test_sty':
+    #     m.test_sty(opt)
+    # elif opt.task == 'test_interpolation':
+    #     m.test_interpolation(opt)
+    # else:
+    #     raise RuntimeError("Unknow task")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e9dfac88fde48243d964d294b548989c999836
--- /dev/null
+++ b/train.py
@@ -0,0 +1,30 @@
+import os,sys
+import importlib
+import options
+import warnings
+import wandb
+warnings.filterwarnings("ignore")
+os.environ['WANDB_IGNORE_GLOBS'] = '*.pth'  # not save checkpoint in cloud
+
+def main():
+    opt_cmd = options.parse_arguments(sys.argv[1:])
+    opt = options.set(opt_cmd=opt_cmd)
+    assert opt.task in ["train","Train"]
+    opt.isTrain = True
+    opt.name = opt.yaml if opt.name is None else opt.name
+    wandb_log = wandb.init(
+        project=opt.project,
+        name=opt.name,
+        group=opt.Group,
+        config=opt,
+        )
+    mode = importlib.import_module("model.{}".format(opt.model))
+    m = mode.Model(opt,wandb_log)
+
+    m.load_dataset(opt)
+    m.build_networks(opt)
+    m.setup_optimizer(opt)
+    m.train(opt)
+    
+if __name__=="__main__":
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4967e14d5cd14425c34a32392f4c939e90db3fea
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,166 @@
+import termcolor,os,shutil,torch
+from easydict import EasyDict as edict
+from collections import OrderedDict
+import math
+import numpy as np
+from torch.nn import init
+
+def get_time(sec):
+    """
+    Convert seconds to days, hours, minutes, and seconds
+    """
+    d = int(sec//(24*60*60))
+    h = int(sec//(60*60)%24)
+    m = int((sec//60)%60)
+    s = int(sec%60)
+    return d,h,m,s
+
+# convert to colored strings
+def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True])
+def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True])
+def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True])
+def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True])
+def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True])
+def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True])
+def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True])
+
+
+
+def openreadtxt(file_name):
+    # read a text file and return its lines as a list of strings
+    with open(file_name, 'r') as file:
+        file_data = file.read().splitlines()
+    return file_data
+
+def to_dict(D,dict_type=dict):
+    D = dict_type(D)
+    for k,v in D.items():
+        if isinstance(v,dict):
+            D[k] = to_dict(v,dict_type)
+    return D
+
+class Log:
+    def __init__(self): pass
+    def process(self,pid):
+        print(grey("Process ID: {}".format(pid),bold=True))
+    def title(self,message):
+        print(yellow(message,bold=True,underline=True))
+    def info(self,message):
+        print(magenta(message,bold=True))
+    def options(self,opt,level=0):
+        for key,value in sorted(opt.items()):
+            if isinstance(value,(dict,edict)):
+                print("   "*level+cyan("* ")+green(key)+":")
+                self.options(value,level+1)
+            else:
+                print("   "*level+cyan("* ")+green(key)+":",yellow(value))
+    def loss_train(self,opt,ep,lr,loss,timer):
+        if not opt.max_epoch: return
+        message = grey("[train] ",bold=True)
+        message += "epoch {}/{}".format(cyan(ep,bold=True),opt.max_epoch)
+        message += ", lr:{}".format(yellow("{:.2e}".format(lr),bold=True))
+        message += ", loss:{}".format(red("{:.3e}".format(loss),bold=True))
+        message += ", time:{}".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.elapsed)),bold=True))
+        message += " (ETA:{})".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.arrival))))
+        print(message)
+    def loss_val(self,opt,loss):
+        message = grey("[val] ",bold=True)
+        message += "loss:{}".format(red("{:.3e}".format(loss),bold=True))
+        print(message)
+log = Log()
+
+def save_checkpoint(model,ep,latest=False,children=None,output_path=None):
+
+    os.makedirs("{0}/model".format(output_path),exist_ok=True)
+    checkpoint = dict(
+        epoch=ep,
+        netG=model.netG.state_dict(),
+        netD=model.netD.state_dict()
+        )
+
+    torch.save(checkpoint,"{0}/model.pth".format(output_path))
+    if not latest:
+        shutil.copy("{0}/model.pth".format(output_path),
+                    "{0}/model/{1}.pth".format(output_path,ep)) # if ep is None, track it instead
+
+def filt_ckpt_keys(ckpt, item_name, model_name):
+    # if item_name in ckpt:
+    assert item_name in ckpt, "Cannot find [%s] in the checkpoints." % item_name
+    d = ckpt[item_name]
+    d_filt = OrderedDict()
+    for k, v in d.items():
+        k_list = k.split('.')
+        if k_list[0] == model_name:
+            if k_list[1] == 'module':
+                d_filt['.'.join(k_list[2:])] = v
+            else:
+                d_filt['.'.join(k_list[1:])] = v
+    return d_filt
+
+def requires_grad(model, flag=True):
+    for p in model.parameters():
+        p.requires_grad = flag
+
+def get_ray_pano(batch_img):
+    _,_,H,W = batch_img.size()
+    _y = np.repeat(np.array(range(W)).reshape(1,W), H, axis=0)
+    _x = np.repeat(np.array(range(H)).reshape(1,H), W, axis=0).T
+    
+    _theta = (1 - 2 * (_x) / H) * np.pi/2 # latitude
+    _phi = 2*math.pi*(0.5 - (_y)/W ) # longitude
+    axis0 = (np.cos(_theta)*np.cos(_phi)).reshape(1,H, W)
+    axis1 = np.sin(_theta).reshape(1,H, W)
+    axis2 = (-np.cos(_theta)*np.sin(_phi)).reshape(1, H, W)
+    original_coord = np.concatenate((axis0, axis1, axis2), axis=0)
+
+    return original_coord
+
+def init_weights(net, init_type='kaiming', init_gain=0.02):
+    """Initialize network weights.
+    Parameters:
+        net (network)   -- network to be initialized
+        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.
+    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+    work better for some applications. Feel free to try yourself.
+    """
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+            init.normal_(m.weight.data, 1.0, init_gain)
+            init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)
+
+
+if __name__=='__main__':
+    a = torch.zeros([2,3,200,100])
+    cood = get_ray_pano(a)