diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..aae7e776f35576220e72ab7b866e471fb9a702aa
Binary files /dev/null and b/.DS_Store differ
diff --git a/CodeFormer/.DS_Store b/CodeFormer/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..667d79118ce06e32d60daf6c577ce427c2e9eb9c
Binary files /dev/null and b/CodeFormer/.DS_Store differ
diff --git a/CodeFormer/.gitignore b/CodeFormer/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..18b62a49768403d1a155456e487b22491d1554cb
--- /dev/null
+++ b/CodeFormer/.gitignore
@@ -0,0 +1,129 @@
+.vscode
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+# *.png
+# *.jpeg
+# *.jpg
+*.pt
+*.gif
+*.pth
+*.dat
+*.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# project
+results/
+dlib/
+*.pth
+*_old*
+
diff --git a/CodeFormer/README.md b/CodeFormer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..65810cdf4ce36d8ba152de80df00fa4c8802ee81
--- /dev/null
+++ b/CodeFormer/README.md
@@ -0,0 +1,123 @@
+<p align="center">
+  <img src="assets/CodeFormer_logo.png" height=110>
+</p>
+
+## Towards Robust Blind Face Restoration with Codebook Lookup Transformer
+
+[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+
+
+<a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
+
+[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 
+
+S-Lab, Nanyang Technological University
+
+<img src="assets/network.jpg" width="800px"/>
+
+
+:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs: 
+
+### Update
+
+- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
+- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
+- **2022.08.23**: Some modifications to face detection and fusion for better AI-created face enhancement.
+- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
+- **2022.07.29**: Integrate new face detectors: `['RetinaFace' (default), 'YOLOv5']`.
+- **2022.07.17**: Add Colab demo of CodeFormer. <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
+- **2022.07.16**: Release inference code for face restoration. :blush:
+- **2022.06.21**: This repo is created.
+
+### TODO
+- [ ] Add checkpoint for face inpainting
+- [ ] Add training code and config files
+- [x] ~~Add background image enhancement~~
+
+#### Face Restoration
+
+<img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
+<img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>
+
+#### Face Color Enhancement and Restoration
+
+<img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>
+
+#### Face Inpainting
+
+<img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>
+
+
+
+### Dependencies and Installation
+
+- Pytorch >= 1.7.1
+- CUDA >= 10.1
+- Other required packages in `requirements.txt`
+```
+# git clone this repository
+git clone https://github.com/sczhou/CodeFormer
+cd CodeFormer
+
+# create new anaconda env
+conda create -n codeformer python=3.8 -y
+conda activate codeformer
+
+# install python dependencies
+pip3 install -r requirements.txt
+python basicsr/setup.py develop
+```
+<!-- conda install -c conda-forge dlib -->
+
+### Quick Inference
+
+##### Download Pre-trained Models:
+Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can download the pretrained models manually, or download them by running the following command:
+```
+python scripts/download_pretrained_models.py facelib
+```
+
+Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can download the pretrained models manually, or download them by running the following command:
+```
+python scripts/download_pretrained_models.py CodeFormer
+```
+
+##### Prepare Testing Data:
+You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder.
+
+
+##### Testing on Face Restoration:
+```
+# For cropped and aligned faces
+python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
+
+# For the whole images
+# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+# Add '--face_upsample' to further upsample the restored faces with Real-ESRGAN
+python inference_codeformer.py --w 0.7 --test_path [input folder]
+```
+
+NOTE that *w* is in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result. 
+
+The results will be saved in the `results` folder.
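+
+To compare several fidelity settings on the same inputs, a small sweep over *w* can be scripted, for example (an illustrative sketch, not part of the repo; paths and flags follow the defaults above):
+```
+# sweep the fidelity weight w over the same test folder (Python)
+import subprocess
+
+for w in (0.3, 0.5, 0.7, 0.9):
+    subprocess.run(
+        ["python", "inference_codeformer.py", "--w", str(w),
+         "--test_path", "inputs/TestWhole"],
+        check=True,
+    )
+```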
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+    @article{zhou2022codeformer,
+        author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+        title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+        journal = {arXiv preprint arXiv:2206.11253},
+        year = {2022}
+    }
+
+### License
+
+<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
+
+### Acknowledgement
+
+This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). We also borrow some code from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). Thanks for their awesome work.
+
+### Contact
+If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
\ No newline at end of file
diff --git a/CodeFormer/assets/CodeFormer_logo.png b/CodeFormer/assets/CodeFormer_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..024cb724f43c2b5cff7039c69b78f261a5a4898c
Binary files /dev/null and b/CodeFormer/assets/CodeFormer_logo.png differ
diff --git a/CodeFormer/assets/color_enhancement_result1.png b/CodeFormer/assets/color_enhancement_result1.png
new file mode 100644
index 0000000000000000000000000000000000000000..34433db6378b37cb47a1e544217e4d7f679f7038
Binary files /dev/null and b/CodeFormer/assets/color_enhancement_result1.png differ
diff --git a/CodeFormer/assets/color_enhancement_result2.png b/CodeFormer/assets/color_enhancement_result2.png
new file mode 100644
index 0000000000000000000000000000000000000000..228690ac9b1453e67e0212ab2952bea887543a09
Binary files /dev/null and b/CodeFormer/assets/color_enhancement_result2.png differ
diff --git a/CodeFormer/assets/inpainting_result1.png b/CodeFormer/assets/inpainting_result1.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c6fa68ad4340c0281e096f7928d28be831c00c1
Binary files /dev/null and b/CodeFormer/assets/inpainting_result1.png differ
diff --git a/CodeFormer/assets/inpainting_result2.png b/CodeFormer/assets/inpainting_result2.png
new file mode 100644
index 0000000000000000000000000000000000000000..2945f9f91c93c329c5e66d4e8519dbb3f90fa1b5
Binary files /dev/null and b/CodeFormer/assets/inpainting_result2.png differ
diff --git a/CodeFormer/assets/network.jpg b/CodeFormer/assets/network.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5aaa6bd1b0f71bf28e5f175c2cda1e7b34b8aa5f
Binary files /dev/null and b/CodeFormer/assets/network.jpg differ
diff --git a/CodeFormer/assets/restoration_result1.png b/CodeFormer/assets/restoration_result1.png
new file mode 100644
index 0000000000000000000000000000000000000000..8fd3b67ec9a5c9b7606ea0515a5b071c1e7a1118
Binary files /dev/null and b/CodeFormer/assets/restoration_result1.png differ
diff --git a/CodeFormer/assets/restoration_result2.png b/CodeFormer/assets/restoration_result2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a2ff282701b6c66a612b3b669512e8d99595ee9f
Binary files /dev/null and b/CodeFormer/assets/restoration_result2.png differ
diff --git a/CodeFormer/assets/restoration_result3.png b/CodeFormer/assets/restoration_result3.png
new file mode 100644
index 0000000000000000000000000000000000000000..022d764266b4d43f4ffea6b1f7ccca63b32e180c
Binary files /dev/null and b/CodeFormer/assets/restoration_result3.png differ
diff --git a/CodeFormer/assets/restoration_result4.png b/CodeFormer/assets/restoration_result4.png
new file mode 100644
index 0000000000000000000000000000000000000000..5e965076c7b5fae051dc2df354f74c0864ec4214
Binary files /dev/null and b/CodeFormer/assets/restoration_result4.png differ
diff --git a/CodeFormer/basicsr/.DS_Store b/CodeFormer/basicsr/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..dd20dcb38f40adde1a9266896e27431765a096cf
Binary files /dev/null and b/CodeFormer/basicsr/.DS_Store differ
diff --git a/CodeFormer/basicsr/VERSION b/CodeFormer/basicsr/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..1892b926767774e9ba91f1e584fa71b4c56abb69
--- /dev/null
+++ b/CodeFormer/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/CodeFormer/basicsr/__init__.py b/CodeFormer/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ffcccd7fc0f33b59d99d73d0436d60e561b0fc
--- /dev/null
+++ b/CodeFormer/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/CodeFormer/basicsr/archs/__init__.py b/CodeFormer/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb1e4d7bb221c429082bd389d9140e5b1cc07b0
--- /dev/null
+++ b/CodeFormer/basicsr/archs/__init__.py
@@ -0,0 +1,25 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+    opt = deepcopy(opt)
+    network_type = opt.pop('type')
+    net = ARCH_REGISTRY.get(network_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Network [{net.__class__.__name__}] is created.')
+    return net
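+
+
+# Example (hypothetical option dict): build_network pops the 'type' key and passes
+# the remaining entries as keyword arguments to the registered class, e.g.
+#
+#     opt = {'type': 'RRDBNet', 'num_in_ch': 3, 'num_out_ch': 3, 'scale': 4}
+#     net = build_network(opt)
+#
+# would create basicsr.archs.rrdbnet_arch.RRDBNet(num_in_ch=3, num_out_ch=3, scale=4).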
diff --git a/CodeFormer/basicsr/archs/arcface_arch.py b/CodeFormer/basicsr/archs/arcface_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5afb7bd2b359e0c2b7efdf628ab10b63964d87
--- /dev/null
+++ b/CodeFormer/basicsr/archs/arcface_arch.py
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+    """A simple wrapper for 3x3 convolution with padding.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        outplanes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+    """
+    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    """Basic residual block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 1  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class IRBlock(nn.Module):
+    """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+    """
+    expansion = 1  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+        super(IRBlock, self).__init__()
+        self.bn0 = nn.BatchNorm2d(inplanes)
+        self.conv1 = conv3x3(inplanes, inplanes)
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.prelu = nn.PReLU()
+        self.conv2 = conv3x3(inplanes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.use_se = use_se
+        if self.use_se:
+            self.se = SEBlock(planes)
+
+    def forward(self, x):
+        residual = x
+        out = self.bn0(x)
+        out = self.conv1(out)
+        out = self.bn1(out)
+        out = self.prelu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.use_se:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.prelu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    """Bottleneck block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 4  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class SEBlock(nn.Module):
+    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+    Args:
+        channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+    """
+
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+    """ArcFace with ResNet architectures.
+
+    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+    Args:
+        block (str): Block used in the ArcFace architecture.
+        layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+    """
+
+    def __init__(self, block, layers, use_se=True):
+        if block == 'IRBlock':
+            block = IRBlock
+        self.inplanes = 64
+        self.use_se = use_se
+        super(ResNetArcFace, self).__init__()
+
+        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.prelu = nn.PReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.bn4 = nn.BatchNorm2d(512)
+        self.dropout = nn.Dropout()
+        self.fc5 = nn.Linear(512 * 8 * 8, 512)
+        self.bn5 = nn.BatchNorm1d(512)
+
+        # initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.xavier_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, num_blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+        self.inplanes = planes
+        for _ in range(1, num_blocks):
+            layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.bn4(x)
+        x = self.dropout(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc5(x)
+        x = self.bn5(x)
+
+        return x
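+
+
+# Example usage (illustrative): with the typical config the network takes
+# single-channel 128x128 face crops, since conv1 has 1 input channel and fc5
+# expects an 8x8 spatial map after the maxpool and three stride-2 layers.
+#
+#     net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True)
+#     feat = net(x)  # x: (4, 1, 128, 128) -> (4, 512) identity embedding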
\ No newline at end of file
diff --git a/CodeFormer/basicsr/archs/arch_util.py b/CodeFormer/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad45ab34e901c47fb539152fca714a3795b0de2
--- /dev/null
+++ b/CodeFormer/basicsr/archs/arch_util.py
@@ -0,0 +1,318 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
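+
+    Example (illustrative): Upsample(scale=4, num_feat=64) stacks two
+    Conv2d(64, 256, 3) + PixelShuffle(2) stages, i.e. a x4 upsampling module.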
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
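+
+    Example (illustrative; identity warp under the default settings):
+        >>> x = torch.randn(1, 3, 8, 8)
+        >>> flow = torch.zeros(1, 8, 8, 2)
+        >>> torch.allclose(flow_warp(x, flow), x, atol=1e-5)
+        True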
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+            ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
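+
+    Example (illustrative): with scale=2, each 2x2 spatial phase is moved into
+    its own group of channels.
+        >>> x = torch.randn(1, 3, 8, 8)
+        >>> pixel_unshuffle(x, scale=2).shape
+        torch.Size([1, 12, 4, 4])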
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+    """Modulated deformable conv for deformable alignment.
+
+    Different from the official DCNv2Pack, which generates offsets and masks
+    from the preceding features, this DCNv2Pack takes a separate feature map
+    to generate the offsets and masks.
+
+    Ref:
+        Delving Deep into Deformable Alignment in Video Super-Resolution.
+    """
+
+    def forward(self, x, feat):
+        out = self.conv_offset(feat)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+
+        offset_absmean = torch.mean(torch.abs(offset))
+        if offset_absmean > 50:
+            logger = get_root_logger()
+            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+                                                 self.dilation, mask)
+        else:
+            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+                                         self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        low = norm_cdf((a - mean) / std)
+        up = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [low, up], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution.
+
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+    The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
\ No newline at end of file
diff --git a/CodeFormer/basicsr/archs/codeformer_arch.py b/CodeFormer/basicsr/archs/codeformer_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d0d8027c8c4ffb26af6f4ba361514e93e320e8d
--- /dev/null
+++ b/CodeFormer/basicsr/archs/codeformer_arch.py
@@ -0,0 +1,276 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from basicsr.archs.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+    """Calculate mean and std for adaptive_instance_normalization.
+
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    b, c = size[:2]
+    feat_var = feat.view(b, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(b, c, 1, 1)
+    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    """Adaptive instance normalization.
+
+    Adjust the reference features to have similar color and illumination to
+    those in the degraded features.
+
+    Args:
+        content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
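+
+    For an input of shape (b, c, h, w), forward returns a positional encoding of
+    shape (b, 2 * num_pos_feats, h, w): sine/cosine features of the y and x
+    coordinates, concatenated along the channel dimension.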
+    """
+
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
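+    # Pre-norm transformer layer: LayerNorm -> self-attention -> residual,
+    # then LayerNorm -> MLP -> residual. query_pos (if given) is added to the
+    # queries and keys only, not to the values.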
+    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+        # Implementation of Feedforward model - MLP
+        self.linear1 = nn.Linear(embed_dim, dim_mlp)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+        self.norm1 = nn.LayerNorm(embed_dim)
+        self.norm2 = nn.LayerNorm(embed_dim)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self, tgt,
+                tgt_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+        
+        # self attention
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+
+        # ffn
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+class Fuse_sft_block(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+        self.scale = nn.Sequential(
+                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+                    nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+        self.shift = nn.Sequential(
+                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+                    nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+    def forward(self, enc_feat, dec_feat, w=1):
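+        # w scales the SFT residual: w=0 returns the decoder feature unchanged,
+        # while w=1 applies the full scale/shift predicted from the fused
+        # encoder/decoder features (controllable fidelity).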
+        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+        scale = self.scale(enc_feat)
+        shift = self.shift(enc_feat)
+        residual = w * (dec_feat * scale + shift)
+        out = dec_feat + residual
+        return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+    def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
+                codebook_size=1024, latent_size=256,
+                connect_list=['32', '64', '128', '256'],
+                fix_modules=['quantize','generator']):
+        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
+
+        if fix_modules is not None:
+            for module in fix_modules:
+                for param in getattr(self, module).parameters():
+                    param.requires_grad = False
+
+        self.connect_list = connect_list
+        self.n_layers = n_layers
+        self.dim_embd = dim_embd
+        self.dim_mlp = dim_embd*2
+
+        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+        self.feat_emb = nn.Linear(256, self.dim_embd)
+
+        # transformer
+        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) 
+                                    for _ in range(self.n_layers)])
+
+        # logits_predict head
+        self.idx_pred_layer = nn.Sequential(
+            nn.LayerNorm(dim_embd),
+            nn.Linear(dim_embd, codebook_size, bias=False))
+        
+        self.channels = {
+            '16': 512,
+            '32': 256,
+            '64': 256,
+            '128': 128,
+            '256': 128,
+            '512': 64,
+        }
+
+        # after second residual block for > 16, before attn layer for ==16
+        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+        # after first residual block for > 16, before attn layer for ==16
+        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+        # fuse_convs_dict
+        self.fuse_convs_dict = nn.ModuleDict()
+        for f_size in self.connect_list:
+            in_ch = self.channels[f_size]
+            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+        # ################### Encoder #####################
+        enc_feat_dict = {}
+        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+        for i, block in enumerate(self.encoder.blocks):
+            x = block(x) 
+            if i in out_list:
+                enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+        lq_feat = x
+        # ################# Transformer ###################
+        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+        # BCHW -> BC(HW) -> (HW)BC
+        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+        query_emb = feat_emb
+        # Transformer encoder
+        for layer in self.ft_layers:
+            query_emb = layer(query_emb, query_pos=pos_emb)
+
+        # output logits
+        logits = self.idx_pred_layer(query_emb) # (hw)bn
+        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+        if code_only: # for training stage II
+            # logits doesn't need softmax before cross_entropy loss
+            return logits, lq_feat
+
+        # ################# Quantization ###################
+        # if self.training:
+        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+        #     # b(hw)c -> bc(hw) -> bchw
+        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+        # ------------
+        soft_one_hot = F.softmax(logits, dim=2)
+        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+        # preserve gradients
+        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+        if detach_16:
+            quant_feat = quant_feat.detach() # for training stage III
+        if adain:
+            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+        # ################## Generator ####################
+        x = quant_feat
+        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+        for i, block in enumerate(self.generator.blocks):
+            x = block(x) 
+            if i in fuse_list: # fuse after i-th block
+                f_size = str(x.shape[-1])
+                if w>0:
+                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+        out = x
+        # logits doesn't need softmax before cross_entropy loss
+        return out, logits, lq_feat
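+
+
+# Example usage (illustrative sketch; in practice a pretrained checkpoint is
+# loaded into the model first):
+#
+#     net = CodeFormer(dim_embd=512, n_head=8, n_layers=9, codebook_size=1024)
+#     out, logits, lq_feat = net(x, w=0.5)   # x: (b, 3, 512, 512) aligned face
+#
+# w in [0, 1] controls how strongly encoder features are fused back into the
+# generator: larger w keeps more fidelity to the degraded input, smaller w
+# relies more on the codebook prediction (typically cleaner output).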
\ No newline at end of file
diff --git a/CodeFormer/basicsr/archs/rrdbnet_arch.py b/CodeFormer/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a2d6c204557cba53ada7550deb587541855cfb
--- /dev/null
+++ b/CodeFormer/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32):
+        super(ResidualDenseBlock, self).__init__()
+        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat, num_grow_ch=32):
+        super(RRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+    def forward(self, x):
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+    """Networks consisting of Residual in Residual Dense Block, which is used
+    in ESRGAN.
+
+    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+    We extend ESRGAN for scale x2 and scale x1.
+    Note: This is one option for scale 1, scale 2 in RRDBNet.
+    We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64
+        num_block (int): Block number in the trunk network. Defaults: 23
+        num_grow_ch (int): Channels for each growth. Default: 32.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+        super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        # upsample
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        if self.scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        elif self.scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        else:
+            feat = x
+        feat = self.conv_first(feat)
+        body_feat = self.conv_body(self.body(feat))
+        feat = feat + body_feat
+        # upsample
+        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+        return out
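+
+
+# Example usage (illustrative): with the default scale=4, an RGB input of shape
+# (1, 3, 64, 64) is upscaled to (1, 3, 256, 256).
+#
+#     net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23)
+#     sr = net(torch.randn(1, 3, 64, 64))   # -> (1, 3, 256, 256)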
\ No newline at end of file
diff --git a/CodeFormer/basicsr/archs/vgg_arch.py b/CodeFormer/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bb0103c8b14ef2588028f7177753db9af62cae
--- /dev/null
+++ b/CodeFormer/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+    'vgg11': [
+        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+        'pool5'
+    ],
+    'vgg13': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+    ],
+    'vgg16': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+        'pool5'
+    ],
+    'vgg19': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+    ]
+}
+
+
+def insert_bn(names):
+    """Insert bn layer after each conv.
+
+    Args:
+        names (list): The list of layer names.
+
+    Returns:
+        list: The list of layer names with bn layers.
+    """
+    names_bn = []
+    for name in names:
+        names_bn.append(name)
+        if 'conv' in name:
+            position = name.replace('conv', '')
+            names_bn.append('bn' + position)
+    return names_bn
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+    """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to use normalization
+    on the input feature and which type of VGG network to use. Note that the
+    pretrained path must match the VGG type.
+
+    Args:
+        layer_name_list (list[str]): Forward function returns the corresponding
+            features according to the layer_name_list.
+            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+        range_norm (bool): If True, normalize images from the range [-1, 1] to [0, 1].
+            Default: False.
+        requires_grad (bool): If True, the parameters of the VGG network will be
+            optimized. Default: False.
+        remove_pooling (bool): If True, the max pooling operations in the VGG net
+            will be removed. Default: False.
+        pooling_stride (int): The stride of max pooling operation. Default: 2.
+    """
+
+    def __init__(self,
+                 layer_name_list,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 range_norm=False,
+                 requires_grad=False,
+                 remove_pooling=False,
+                 pooling_stride=2):
+        super(VGGFeatureExtractor, self).__init__()
+
+        self.layer_name_list = layer_name_list
+        self.use_input_norm = use_input_norm
+        self.range_norm = range_norm
+
+        self.names = NAMES[vgg_type.replace('_bn', '')]
+        if 'bn' in vgg_type:
+            self.names = insert_bn(self.names)
+
+        # only borrow layers that will be used to avoid unused params
+        max_idx = 0
+        for v in layer_name_list:
+            idx = self.names.index(v)
+            if idx > max_idx:
+                max_idx = idx
+
+        if os.path.exists(VGG_PRETRAIN_PATH):
+            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+            vgg_net.load_state_dict(state_dict)
+        else:
+            vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+        features = vgg_net.features[:max_idx + 1]
+
+        modified_net = OrderedDict()
+        for k, v in zip(self.names, features):
+            if 'pool' in k:
+                # if remove_pooling is true, pooling operation will be removed
+                if remove_pooling:
+                    continue
+                else:
+                    # in some cases, we may want to change the default stride
+                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+            else:
+                modified_net[k] = v
+
+        self.vgg_net = nn.Sequential(modified_net)
+
+        if not requires_grad:
+            self.vgg_net.eval()
+            for param in self.parameters():
+                param.requires_grad = False
+        else:
+            self.vgg_net.train()
+            for param in self.parameters():
+                param.requires_grad = True
+
+        if self.use_input_norm:
+            # the mean is for image with range [0, 1]
+            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            # the std is for image with range [0, 1]
+            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        if self.range_norm:
+            x = (x + 1) / 2
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+        output = {}
+
+        for key, layer in self.vgg_net._modules.items():
+            x = layer(x)
+            if key in self.layer_name_list:
+                output[key] = x.clone()
+
+        return output
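+
+
+# A minimal usage sketch (illustrative only; the layer names and input size are
+# arbitrary assumptions). Instantiating the extractor loads VGG_PRETRAIN_PATH if
+# it exists, otherwise the pretrained torchvision weights are downloaded:
+#   >>> import torch
+#   >>> extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'])
+#   >>> feats = extractor(torch.rand(1, 3, 224, 224))  # RGB image in [0, 1]
+#   >>> sorted(feats.keys())
+#   ['relu1_1', 'relu2_1', 'relu3_1']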
diff --git a/CodeFormer/basicsr/archs/vqgan_arch.py b/CodeFormer/basicsr/archs/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6dfcf4c9983b431f0a978701e5ddd9598faf381
--- /dev/null
+++ b/CodeFormer/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,435 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+    
+
+@torch.jit.script
+def swish(x):
+    return x*torch.sigmoid(x)
+
+
+#  Define VQVAE classes
+class VectorQuantizer(nn.Module):
+    def __init__(self, codebook_size, emb_dim, beta):
+        super(VectorQuantizer, self).__init__()
+        self.codebook_size = codebook_size  # number of embeddings
+        self.emb_dim = emb_dim  # dimension of embedding
+        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.view(-1, self.emb_dim)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+            2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+        mean_distance = torch.mean(d)
+        # find closest encodings
+        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+        min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+        # [0-1], higher score, higher confidence
+        min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+        min_encodings.scatter_(1, min_encoding_indices, 1)
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+        # compute loss for embedding
+        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # perplexity
+        e_mean = torch.mean(min_encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+        # reshape back to match original input shape
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q, loss, {
+            "perplexity": perplexity,
+            "min_encodings": min_encodings,
+            "min_encoding_indices": min_encoding_indices,
+            "min_encoding_scores": min_encoding_scores,
+            "mean_distance": mean_distance
+            }
+
+    def get_codebook_feat(self, indices, shape):
+        # input indices: batch*token_num -> (batch*token_num)*1
+        # shape: batch, height, width, channel
+        indices = indices.view(-1,1)
+        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+        min_encodings.scatter_(1, indices, 1)
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+        if shape is not None:  # reshape back to match original input shape
+            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+        return z_q
+
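+# A minimal check of VectorQuantizer (illustrative only; the batch and spatial
+# sizes are arbitrary assumptions). The quantized output keeps the input shape:
+#   >>> vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
+#   >>> z_q, loss, stats = vq(torch.randn(2, 256, 16, 16))  # (b, emb_dim, h, w)
+#   >>> z_q.shape
+#   torch.Size([2, 256, 16, 16])
+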
+
+class GumbelQuantizer(nn.Module):
+    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+        super().__init__()
+        self.codebook_size = codebook_size  # number of embeddings
+        self.emb_dim = emb_dim  # dimension of embedding
+        self.straight_through = straight_through
+        self.temperature = temp_init
+        self.kl_weight = kl_weight
+        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
+        self.embed = nn.Embedding(codebook_size, emb_dim)
+
+    def forward(self, z):
+        hard = self.straight_through if self.training else True
+
+        logits = self.proj(z)
+
+        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+        # + kl divergence to the prior loss
+        qy = F.softmax(logits, dim=1)
+        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+        min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+        return z_q, diff, {
+            "min_encoding_indices": min_encoding_indices
+        }
+
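+# A minimal check of GumbelQuantizer (illustrative only; sizes are arbitrary
+# assumptions). num_hiddens must match the channel count of the input features:
+#   >>> gq = GumbelQuantizer(codebook_size=1024, emb_dim=256, num_hiddens=256).eval()
+#   >>> z_q, kl_loss, stats = gq(torch.randn(2, 256, 16, 16))
+#   >>> z_q.shape  # (b, emb_dim, h, w)
+#   torch.Size([2, 256, 16, 16])
+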
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x):
+        pad = (0, 1, 0, 1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = self.conv(x)
+
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels=None):
+        super(ResBlock, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.norm1 = normalize(in_channels)
+        # use self.out_channels so the default out_channels=None falls back to in_channels
+        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = normalize(self.out_channels)
+        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x_in):
+        x = x_in
+        x = self.norm1(x)
+        x = swish(x)
+        x = self.conv1(x)
+        x = self.norm2(x)
+        x = swish(x)
+        x = self.conv2(x)
+        if self.in_channels != self.out_channels:
+            x_in = self.conv_out(x_in)
+
+        return x + x_in
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention: w_[b, i, j] = softmax_j(q[b, i, :] . k[b, :, j] / sqrt(c))
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h*w)
+        q = q.permute(0, 2, 1)    # b, hw, c
+        k = k.reshape(b, c, h*w)  # b, c, hw
+        w_ = torch.bmm(q, k)      # b, hw (queries), hw (keys)
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = F.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h*w)
+        w_ = w_.permute(0, 2, 1)  # b, hw (keys), hw (queries)
+        h_ = torch.bmm(v, w_)     # b, c, hw (queries)
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+        super().__init__()
+        self.nf = nf
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.attn_resolutions = attn_resolutions
+
+        curr_res = self.resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+
+        blocks = []
+        # initial convolution
+        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+        # residual and downsampling blocks, with attention on smaller res (16x16)
+        for i in range(self.num_resolutions):
+            block_in_ch = nf * in_ch_mult[i]
+            block_out_ch = nf * ch_mult[i]
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch))
+                block_in_ch = block_out_ch
+                if curr_res in attn_resolutions:
+                    blocks.append(AttnBlock(block_in_ch))
+
+            if i != self.num_resolutions - 1:
+                blocks.append(Downsample(block_in_ch))
+                curr_res = curr_res // 2
+
+        # non-local attention block
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+        blocks.append(AttnBlock(block_in_ch))
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+        # normalise and convert to latent size
+        blocks.append(normalize(block_in_ch))
+        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+            
+        return x
+
+
+class Generator(nn.Module):
+    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+        super().__init__()
+        self.nf = nf 
+        self.ch_mult = ch_mult 
+        self.num_resolutions = len(self.ch_mult)
+        self.num_res_blocks = res_blocks
+        self.resolution = img_size 
+        self.attn_resolutions = attn_resolutions
+        self.in_channels = emb_dim
+        self.out_channels = 3
+        block_in_ch = self.nf * self.ch_mult[-1]
+        curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+        blocks = []
+        # initial conv
+        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+        # non-local attention block
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+        blocks.append(AttnBlock(block_in_ch))
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+        for i in reversed(range(self.num_resolutions)):
+            block_out_ch = self.nf * self.ch_mult[i]
+
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch))
+                block_in_ch = block_out_ch
+
+                if curr_res in self.attn_resolutions:
+                    blocks.append(AttnBlock(block_in_ch))
+
+            if i != 0:
+                blocks.append(Upsample(block_in_ch))
+                curr_res = curr_res * 2
+
+        blocks.append(normalize(block_in_ch))
+        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+        self.blocks = nn.ModuleList(blocks)
+   
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+            
+        return x
+
+  
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+                beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+        super().__init__()
+        logger = get_root_logger()
+        self.in_channels = 3 
+        self.nf = nf 
+        self.n_blocks = res_blocks 
+        self.codebook_size = codebook_size
+        self.embed_dim = emb_dim
+        self.ch_mult = ch_mult
+        self.resolution = img_size
+        self.attn_resolutions = attn_resolutions
+        self.quantizer_type = quantizer
+        self.encoder = Encoder(
+            self.in_channels,
+            self.nf,
+            self.embed_dim,
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
+            self.attn_resolutions
+        )
+        if self.quantizer_type == "nearest":
+            self.beta = beta #0.25
+            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+        elif self.quantizer_type == "gumbel":
+            self.gumbel_num_hiddens = emb_dim
+            self.straight_through = gumbel_straight_through
+            self.kl_weight = gumbel_kl_weight
+            self.quantize = GumbelQuantizer(
+                self.codebook_size,
+                self.embed_dim,
+                self.gumbel_num_hiddens,
+                self.straight_through,
+                self.kl_weight
+            )
+        self.generator = Generator(
+            self.nf, 
+            self.embed_dim,
+            self.ch_mult, 
+            self.n_blocks, 
+            self.resolution, 
+            self.attn_resolutions
+        )
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_ema' in chkpt:
+                self.load_state_dict(chkpt['params_ema'])
+                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+                logger.info(f'vqgan is loaded from: {model_path} [params]')
+            else:
+                raise ValueError(f'Wrong params in checkpoint: {model_path}')
+
+
+    def forward(self, x):
+        x = self.encoder(x)
+        quant, codebook_loss, quant_stats = self.quantize(x)
+        x = self.generator(quant)
+        return x, codebook_loss, quant_stats
+
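+# A minimal end-to-end sketch (illustrative only; the hyper-parameters below are
+# assumptions for a 512x512 model with a 16x16 latent grid, not values read from
+# a CodeFormer config file):
+#   >>> vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8],
+#   ...                       quantizer='nearest', codebook_size=1024, emb_dim=256)
+#   >>> recon, codebook_loss, stats = vqgan(torch.rand(1, 3, 512, 512))
+#   >>> recon.shape
+#   torch.Size([1, 3, 512, 512])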
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+        super().__init__()
+
+        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+        ndf_mult = 1
+        ndf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            ndf_mult_prev = ndf_mult
+            ndf_mult = min(2 ** n, 8)
+            layers += [
+                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+                nn.BatchNorm2d(ndf * ndf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        ndf_mult_prev = ndf_mult
+        ndf_mult = min(2 ** n_layers, 8)
+
+        layers += [
+            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(ndf * ndf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        layers += [
+            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
+        self.main = nn.Sequential(*layers)
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_d' in chkpt:
+                self.load_state_dict(chkpt['params_d'])
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+            else:
+                raise ValueError(f'Wrong params in checkpoint: {model_path}')
+
+    def forward(self, x):
+        return self.main(x)
\ No newline at end of file
diff --git a/CodeFormer/basicsr/data/__init__.py b/CodeFormer/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6adb4bb6a926af7a46aaec4794eee95fda02a33
--- /dev/null
+++ b/CodeFormer/basicsr/data/__init__.py
@@ -0,0 +1,100 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+    """Build dataset from options.
+
+    Args:
+        dataset_opt (dict): Configuration for dataset. It must contain:
+            name (str): Dataset name.
+            type (str): Dataset type.
+    """
+    dataset_opt = deepcopy(dataset_opt)
+    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+    logger = get_root_logger()
+    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
+    return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+    """Build dataloader.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset.
+        dataset_opt (dict): Dataset options. It contains the following keys:
+            phase (str): 'train' or 'val'.
+            num_worker_per_gpu (int): Number of workers for each GPU.
+            batch_size_per_gpu (int): Training batch size for each GPU.
+        num_gpu (int): Number of GPUs. Used only in the train phase.
+            Default: 1.
+        dist (bool): Whether in distributed training. Used only in the train
+            phase. Default: False.
+        sampler (torch.utils.data.sampler): Data sampler. Default: None.
+        seed (int | None): Seed. Default: None
+    """
+    phase = dataset_opt['phase']
+    rank, _ = get_dist_info()
+    if phase == 'train':
+        if dist:  # distributed training
+            batch_size = dataset_opt['batch_size_per_gpu']
+            num_workers = dataset_opt['num_worker_per_gpu']
+        else:  # non-distributed training
+            multiplier = 1 if num_gpu == 0 else num_gpu
+            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+        dataloader_args = dict(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            sampler=sampler,
+            drop_last=True)
+        if sampler is None:
+            dataloader_args['shuffle'] = True
+        dataloader_args['worker_init_fn'] = partial(
+            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+    elif phase in ['val', 'test']:  # validation
+        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    else:
+        raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
+
+    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+    prefetch_mode = dataset_opt.get('prefetch_mode')
+    if prefetch_mode == 'cpu':  # CPUPrefetcher
+        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+        logger = get_root_logger()
+        logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
+        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+    else:
+        # prefetch_mode=None: Normal dataloader
+        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+        return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # Set the worker seed to num_workers * rank + worker_id + seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
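+
+
+# A minimal sketch of the two builders above (illustrative only). 'SomeDataset'
+# is a placeholder for whatever dataset class is registered in DATASET_REGISTRY,
+# and the option keys simply mirror the docstrings:
+#   >>> dataset_opt = {'name': 'demo', 'type': 'SomeDataset', 'phase': 'train',
+#   ...                'batch_size_per_gpu': 4, 'num_worker_per_gpu': 2}
+#   >>> dataset = build_dataset(dataset_opt)
+#   >>> loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, seed=0)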
diff --git a/CodeFormer/basicsr/data/data_sampler.py b/CodeFormer/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2
--- /dev/null
+++ b/CodeFormer/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+
+    Modified from torch.utils.data.distributed.DistributedSampler.
+    Supports enlarging the dataset for iteration-based training, which saves
+    time when restarting the dataloader after each epoch.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset used for sampling.
+        num_replicas (int | None): Number of processes participating in
+            the training. It is usually the world_size.
+        rank (int | None): Rank of the current process within num_replicas.
+        ratio (int): Enlarging ratio. Default: 1.
+    """
+
+    def __init__(self, dataset, num_replicas, rank, ratio=1):
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = torch.randperm(self.total_size, generator=g).tolist()
+
+        dataset_size = len(self.dataset)
+        indices = [v % dataset_size for v in indices]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
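+
+
+# A minimal sketch (illustrative only), assuming `dataset`, `world_size`, `rank`
+# and `epoch` come from the surrounding training setup:
+#   >>> sampler = EnlargedSampler(dataset, num_replicas=world_size, rank=rank, ratio=1)
+#   >>> sampler.set_epoch(epoch)  # call once per epoch for a deterministic reshuffle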
diff --git a/CodeFormer/basicsr/data/data_util.py b/CodeFormer/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b1bce8e089485182c962e830a163d6d0059da8
--- /dev/null
+++ b/CodeFormer/basicsr/data/data_util.py
@@ -0,0 +1,305 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+    """Read a sequence of images from a given folder path.
+
+    Args:
+        path (list[str] | str): List of image paths or image folder path.
+        require_mod_crop (bool): Require mod crop for each image.
+            Default: False.
+        scale (int): Scale factor for mod_crop. Default: 1.
+
+    Returns:
+        Tensor: size (t, c, h, w), RGB, [0, 1].
+    """
+    if isinstance(path, list):
+        img_paths = path
+    else:
+        img_paths = sorted(list(scandir(path, full_path=True)))
+    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+    if require_mod_crop:
+        imgs = [mod_crop(img, scale) for img in imgs]
+    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+    imgs = torch.stack(imgs, dim=0)
+    return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+    """Generate an index list for reading `num_frames` frames from a sequence
+    of images.
+
+    Args:
+        crt_idx (int): Current center index.
+        max_frame_num (int): Maximum frame number of the sequence (counting
+            from 1).
+        num_frames (int): Number of frames to read.
+        padding (str): Padding mode, one of
+            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+            Examples: current_idx = 0, num_frames = 5
+            The generated frame indices under different padding mode:
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            reflection_circle: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        list[int]: A list of indices.
+    """
+    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+    max_frame_num = max_frame_num - 1  # start from 0
+    num_pad = num_frames // 2
+
+    indices = []
+    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                pad_idx = 0
+            elif padding == 'reflection':
+                pad_idx = -i
+            elif padding == 'reflection_circle':
+                pad_idx = crt_idx + num_pad - i
+            else:
+                pad_idx = num_frames + i
+        elif i > max_frame_num:
+            if padding == 'replicate':
+                pad_idx = max_frame_num
+            elif padding == 'reflection':
+                pad_idx = max_frame_num * 2 - i
+            elif padding == 'reflection_circle':
+                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+            else:
+                pad_idx = i - num_frames
+        else:
+            pad_idx = i
+        indices.append(pad_idx)
+    return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+    """Generate paired paths from lmdb files.
+
+    Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
+
+    lq.lmdb
+    ├── data.mdb
+    ├── lock.mdb
+    ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a specified txt file that records the meta information
+    of our datasets. It is automatically created by our provided dataset
+    preparation tools.
+    Each line in the txt file records
+    1) image name (with extension),
+    2) image shape,
+    3) compression level, separated by a white space.
+    Example: `baboon.png (120,125,3) 1`
+
+    We use the image name without extension as the lmdb key.
+    Note that we use the same key for the corresponding lq and gt images.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+            Note that this key is different from lmdb keys.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                         f'format. But received {input_key}: {input_folder}; '
+                         f'{gt_key}: {gt_folder}')
+    # ensure that the two meta_info files are the same
+    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+        input_lmdb_keys = [line.split('.')[0] for line in fin]
+    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+        gt_lmdb_keys = [line.split('.')[0] for line in fin]
+    if set(input_lmdb_keys) != set(gt_lmdb_keys):
+        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+    else:
+        paths = []
+        for lmdb_key in sorted(input_lmdb_keys):
+            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+        return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from an meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    with open(meta_info_file, 'r') as fin:
+        gt_names = [line.split(' ')[0] for line in fin]
+
+    paths = []
+    for gt_name in gt_names:
+        basename, ext = osp.splitext(osp.basename(gt_name))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        gt_path = osp.join(gt_folder, gt_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+    """Generate paired paths from folders.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
+    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+                                               f'{len(input_paths)}, {len(gt_paths)}.')
+    paths = []
+    for gt_path in gt_paths:
+        basename, ext = osp.splitext(osp.basename(gt_path))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+        gt_path = osp.join(gt_folder, gt_path)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paths_from_folder(folder):
+    """Generate paths from folder.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+
+    paths = list(scandir(folder))
+    paths = [osp.join(folder, path) for path in paths]
+    return paths
+
+
+def paths_from_lmdb(folder):
+    """Generate paths from lmdb.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+    with open(osp.join(folder, 'meta_info.txt')) as fin:
+        paths = [line.split('.')[0] for line in fin]
+    return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+    """Generate Gaussian kernel used in `duf_downsample`.
+
+    Args:
+        kernel_size (int): Kernel size. Default: 13.
+        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+    Returns:
+        np.array: The Gaussian kernel.
+    """
+    from scipy.ndimage import filters as filters
+    kernel = np.zeros((kernel_size, kernel_size))
+    # set element at the middle to one, a dirac delta
+    kernel[kernel_size // 2, kernel_size // 2] = 1
+    # gaussian-smooth the dirac, resulting in a gaussian filter
+    return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+    """Downsamping with Gaussian kernel used in the DUF official code.
+
+    Args:
+        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+        kernel_size (int): Kernel size. Default: 13.
+        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+            Default: 4.
+
+    Returns:
+        Tensor: DUF downsampled frames.
+    """
+    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+    squeeze_flag = False
+    if x.ndim == 4:
+        squeeze_flag = True
+        x = x.unsqueeze(0)
+    b, t, c, h, w = x.size()
+    x = x.view(-1, 1, h, w)
+    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+    x = F.conv2d(x, gaussian_filter, stride=scale)
+    x = x[:, :, 2:-2, 2:-2]
+    x = x.view(b, t, c, x.size(2), x.size(3))
+    if squeeze_flag:
+        x = x.squeeze(0)
+    return x
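+
+
+# A quick shape check for duf_downsample (illustrative only; the frame and batch
+# sizes are arbitrary assumptions):
+#   >>> frames = torch.rand(2, 5, 3, 64, 64)  # (b, t, c, h, w)
+#   >>> duf_downsample(frames, kernel_size=13, scale=4).shape
+#   torch.Size([2, 5, 3, 16, 16])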
diff --git a/CodeFormer/basicsr/data/prefetch_dataloader.py b/CodeFormer/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0
--- /dev/null
+++ b/CodeFormer/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+    """A general prefetch generator.
+
+    Ref:
+    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+    Args:
+        generator: Python generator.
+        num_prefetch_queue (int): Length of the prefetch queue.
+    """
+
+    def __init__(self, generator, num_prefetch_queue):
+        threading.Thread.__init__(self)
+        self.queue = Queue.Queue(num_prefetch_queue)
+        self.generator = generator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class PrefetchDataLoader(DataLoader):
+    """Prefetch version of dataloader.
+
+    Ref:
+    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+    TODO:
+    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+    ddp.
+
+    Args:
+        num_prefetch_queue (int): Length of the prefetch queue.
+        kwargs (dict): Other arguments for dataloader.
+    """
+
+    def __init__(self, num_prefetch_queue, **kwargs):
+        self.num_prefetch_queue = num_prefetch_queue
+        super(PrefetchDataLoader, self).__init__(**kwargs)
+
+    def __iter__(self):
+        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+    """CPU prefetcher.
+
+    Args:
+        loader: Dataloader.
+    """
+
+    def __init__(self, loader):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+
+    def next(self):
+        try:
+            return next(self.loader)
+        except StopIteration:
+            return None
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+    """CUDA prefetcher.
+
+    Ref:
+    https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+    Args:
+        loader: Dataloader.
+        opt (dict): Options.
+    """
+
+    def __init__(self, loader, opt):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+        self.opt = opt
+        self.stream = torch.cuda.Stream()
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.preload()
+
+    def preload(self):
+        try:
+            self.batch = next(self.loader)  # self.batch is a dict
+        except StopIteration:
+            self.batch = None
+            return None
+        # put tensors to gpu
+        with torch.cuda.stream(self.stream):
+            for k, v in self.batch.items():
+                if torch.is_tensor(v):
+                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        self.preload()
+        return batch
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+        self.preload()
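+
+
+# A minimal sketch of the prefetchers above (illustrative only), assuming
+# `loader` is a dataloader from basicsr.data.build_dataloader and `opt` carries
+# a 'num_gpu' entry:
+#   >>> prefetcher = CUDAPrefetcher(loader, opt) if torch.cuda.is_available() else CPUPrefetcher(loader)
+#   >>> batch = prefetcher.next()  # returns None once the loader is exhausted
+#   >>> prefetcher.reset()         # restart from the beginning of the loader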
diff --git a/CodeFormer/basicsr/data/transforms.py b/CodeFormer/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead9dc73ed063e1c5865040eaa2652b26aa3ad3
--- /dev/null
+++ b/CodeFormer/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+    """Mod crop images, used during testing.
+
+    Args:
+        img (ndarray): Input image.
+        scale (int): Scale factor.
+
+    Returns:
+        ndarray: Result image.
+    """
+    img = img.copy()
+    if img.ndim in (2, 3):
+        h, w = img.shape[0], img.shape[1]
+        h_remainder, w_remainder = h % scale, w % scale
+        img = img[:h - h_remainder, :w - w_remainder, ...]
+    else:
+        raise ValueError(f'Wrong img ndim: {img.ndim}.')
+    return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+    """Paired random crop.
+
+    It crops lists of lq and gt images at corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        gt_patch_size (int): GT patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+            only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    h_lq, w_lq, _ = img_lqs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+    We use a vertical flip and a transpose to implement rotation.
+    All the images in the list use the same augmentation.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+        hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+            ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+        return_status (bool): Return the status of flip and rotation.
+            Default: False.
+
+    Returns:
+        list[ndarray] | ndarray: Augmented images and flows. If returned
+            results only have one element, just return ndarray.
+
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        if return_status:
+            return imgs, (hflip, vflip, rot90)
+        else:
+            return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+    """Rotate image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees. Positive values mean
+            counter-clockwise rotation.
+        center (tuple[int]): Rotation center. If the center is None,
+            initialize it as the center of the image. Default: None.
+        scale (float): Isotropic scale factor. Default: 1.0.
+    """
+    (h, w) = img.shape[:2]
+
+    if center is None:
+        center = (w // 2, h // 2)
+
+    matrix = cv2.getRotationMatrix2D(center, angle, scale)
+    rotated_img = cv2.warpAffine(img, matrix, (w, h))
+    return rotated_img
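+
+
+# A minimal sketch of the paired crop and augmentation helpers above
+# (illustrative only; the sizes and the gt_path string are arbitrary assumptions):
+#   >>> import numpy as np
+#   >>> img_gt = np.random.rand(480, 480, 3).astype(np.float32)
+#   >>> img_lq = np.random.rand(120, 120, 3).astype(np.float32)
+#   >>> gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size=256, scale=4, gt_path='demo.png')
+#   >>> gt_patch.shape, lq_patch.shape
+#   ((256, 256, 3), (64, 64, 3))
+#   >>> imgs = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(2)]
+#   >>> imgs_aug, (hflip, vflip, rot90) = augment(imgs, return_status=True)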
diff --git a/CodeFormer/basicsr/losses/__init__.py b/CodeFormer/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b184e74c861e6fca0c548692a9a949a6100b0aa
--- /dev/null
+++ b/CodeFormer/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+                     gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+    'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+    """Build loss from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    loss_type = opt.pop('type')
+    loss = LOSS_REGISTRY.get(loss_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+    return loss
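+
+
+# A minimal sketch of build_loss (illustrative only); 'L1Loss' is registered in
+# basicsr/losses/losses.py and the option keys mirror its constructor arguments:
+#   >>> import torch
+#   >>> l1 = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
+#   >>> l1(torch.zeros(1, 3, 4, 4), torch.ones(1, 3, 4, 4))
+#   tensor(1.)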
diff --git a/CodeFormer/basicsr/losses/loss_util.py b/CodeFormer/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698
--- /dev/null
+++ b/CodeFormer/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are 'none', 'mean' and 'sum'.
+
+    Returns:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    else:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights. Default: None.
+        reduction (str): Same as built-in losses of PyTorch. Options are
+            'none', 'mean' and 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor: Loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        assert weight.dim() == loss.dim()
+        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+        loss = loss * weight
+
+    # if weight is not specified or reduction is sum, just reduce the loss
+    if weight is None or reduction == 'sum':
+        loss = reduce_loss(loss, reduction)
+    # if reduction is mean, then compute mean over weight region
+    elif reduction == 'mean':
+        if weight.size(1) > 1:
+            weight = weight.sum()
+        else:
+            weight = weight.sum() * loss.size(1)
+        loss = loss.sum() / weight
+
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have a signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    a signature like `loss_func(pred, target, weight=None, reduction='mean',
+    **kwargs)`.
+
+    :Example:
+
+    >>> import torch
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.5000)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, reduction='sum')
+    tensor(3.)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction)
+        return loss
+
+    return wrapper
diff --git a/CodeFormer/basicsr/losses/losses.py b/CodeFormer/basicsr/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bcf272cfb756d99451a3005567ea4d4c9059067
--- /dev/null
+++ b/CodeFormer/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+    return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+    return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+    return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+    """L1 (mean absolute error, MAE) loss.
+
+    Args:
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(L1Loss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+    """MSE (L2) loss.
+
+    Args:
+        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(MSELoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+    variant of L1Loss).
+
+    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+        Super-Resolution".
+
+    Args:
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        eps (float): A value used to control the curvature near zero.
+            Default: 1e-12.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+        super(CharbonnierLoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+        self.eps = eps
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+    """Weighted TV loss.
+
+        Args:
+            loss_weight (float): Loss weight. Default: 1.0.
+    """
+
+    def __init__(self, loss_weight=1.0):
+        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        # guard the default weight=None, which the slicing below would otherwise break on
+        y_weight = None if weight is None else weight[:, :, :-1, :]
+        x_weight = None if weight is None else weight[:, :, :, :-1]
+        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+        loss = x_diff + y_diff
+
+        return loss
+
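+# A minimal check (illustrative only; the input size is an arbitrary assumption)
+# of the default, unweighted call enabled by the None guard above:
+#   >>> tv = WeightedTVLoss()
+#   >>> tv(torch.rand(1, 3, 8, 8)).shape
+#   torch.Size([])
+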
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+    """Perceptual loss with commonly used style loss.
+
+    Args:
+        layer_weights (dict): The weight for each layer of vgg feature.
+            Here is an example: {'conv5_4': 1.}, which means the conv5_4
+            feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+        vgg_type (str): The type of vgg network used as feature extractor.
+            Default: 'vgg19'.
+        use_input_norm (bool):  If True, normalize the input image in vgg.
+            Default: True.
+        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+            Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and multiplied by this weight.
+            Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and multiplied by this weight.
+            Default: 0.
+        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+    """
+
+    def __init__(self,
+                 layer_weights,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 range_norm=False,
+                 perceptual_weight=1.0,
+                 style_weight=0.,
+                 criterion='l1'):
+        super(PerceptualLoss, self).__init__()
+        self.perceptual_weight = perceptual_weight
+        self.style_weight = style_weight
+        self.layer_weights = layer_weights
+        self.vgg = VGGFeatureExtractor(
+            layer_name_list=list(layer_weights.keys()),
+            vgg_type=vgg_type,
+            use_input_norm=use_input_norm,
+            range_norm=range_norm)
+
+        self.criterion_type = criterion
+        if self.criterion_type == 'l1':
+            self.criterion = torch.nn.L1Loss()
+        elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()  # PyTorch has no L2Loss; MSE is the L2 criterion
+        elif self.criterion_type == 'mse':
+            self.criterion = torch.nn.MSELoss(reduction='mean')
+        elif self.criterion_type == 'fro':
+            self.criterion = None
+        else:
+            raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+    def forward(self, x, gt):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        # calculate perceptual loss
+        if self.perceptual_weight > 0:
+            percep_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+                else:
+                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+            percep_loss *= self.perceptual_weight
+        else:
+            percep_loss = None
+
+        # calculate style loss
+        if self.style_weight > 0:
+            style_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    style_loss += torch.norm(
+                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+                else:
+                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+                        gt_features[k])) * self.layer_weights[k]
+            style_loss *= self.style_weight
+        else:
+            style_loss = None
+
+        return percep_loss, style_loss
+
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            torch.Tensor: Gram matrix.
+        """
+        n, c, h, w = x.size()
+        features = x.view(n, c, w * h)
+        features_t = features.transpose(1, 2)
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
+
+
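+# Usage sketch for PerceptualLoss with the conv5_4 example from the docstring.
+# Constructing it loads VGG19 weights, so treat this as documentation rather
+# than something to run at import time; shapes are arbitrary illustrations.
+def _perceptual_loss_example():
+    x = torch.rand(1, 3, 128, 128)
+    gt = torch.rand(1, 3, 128, 128)
+    criterion = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
+                               perceptual_weight=1.0, style_weight=0.)
+    percep_loss, style_loss = criterion(x, gt)  # style_loss is None (style_weight=0)
+    return percep_loss
+
+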
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+    def __init__(self,
+                 loss_weight=1.0,
+                 use_input_norm=True,
+                 range_norm=False):
+        super(LPIPSLoss, self).__init__()
+        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+        self.loss_weight = loss_weight
+        self.use_input_norm = use_input_norm
+        self.range_norm = range_norm
+
+        if self.use_input_norm:
+            # the mean is for images with range [0, 1]
+            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            # the std is for images with range [0, 1]
+            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, pred, target):
+        if self.range_norm:
+            pred   = (pred + 1) / 2
+            target = (target + 1) / 2
+        if self.use_input_norm:
+            pred   = (pred - self.mean) / self.std
+            target = (target - self.mean) / self.std
+        lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+        return self.loss_weight * lpips_loss.mean()
+
+
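+# Usage sketch for LPIPSLoss: assumes inputs already lie in [0, 1] (so
+# range_norm stays False) and that the `lpips` package weights are available.
+def _lpips_loss_example():
+    pred = torch.rand(1, 3, 64, 64)
+    target = torch.rand(1, 3, 64, 64)
+    criterion = LPIPSLoss(loss_weight=1.0, use_input_norm=True, range_norm=False)
+    return criterion(pred, target)  # scalar tensor
+
+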
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+    """Define GAN loss.
+
+    Args:
+        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
+        real_label_val (float): The value for real label. Default: 1.0.
+        fake_label_val (float): The value for fake label. Default: 0.0.
+        loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight only applies to generators; it is always 1.0
+            for discriminators.
+    """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(GANLoss, self).__init__()
+        self.gan_type = gan_type
+        self.loss_weight = loss_weight
+        self.real_label_val = real_label_val
+        self.fake_label_val = fake_label_val
+
+        if self.gan_type == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif self.gan_type == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif self.gan_type == 'wgan':
+            self.loss = self._wgan_loss
+        elif self.gan_type == 'wgan_softplus':
+            self.loss = self._wgan_softplus_loss
+        elif self.gan_type == 'hinge':
+            self.loss = nn.ReLU()
+        else:
+            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+    def _wgan_loss(self, input, target):
+        """wgan loss.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return -input.mean() if target else input.mean()
+
+    def _wgan_softplus_loss(self, input, target):
+        """wgan loss with soft plus. softplus is a smooth approximation to the
+        ReLU function.
+
+        In StyleGAN2, it is called:
+            Logistic loss for discriminator;
+            Non-saturating loss for generator.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+    def get_target_label(self, input, target_is_real):
+        """Get target label.
+
+        Args:
+            input (Tensor): Input tensor.
+            target_is_real (bool): Whether the target is real or fake.
+
+        Returns:
+            (bool | Tensor): Target label. Returns a bool for wgan; otherwise
+                returns a Tensor.
+        """
+
+        if self.gan_type in ['wgan', 'wgan_softplus']:
+            return target_is_real
+        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+        return input.new_ones(input.size()) * target_val
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        Args:
+            input (Tensor): The input for the loss module, i.e., the network
+                prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is for discriminators or not.
+                Default: False.
+
+        Returns:
+            Tensor: GAN loss value.
+        """
+        if self.gan_type == 'hinge':
+            if is_disc:  # for discriminators in hinge-gan
+                input = -input if target_is_real else input
+                loss = self.loss(1 + input).mean()
+            else:  # for generators in hinge-gan
+                loss = -input.mean()
+        else:  # other gan types
+            target_label = self.get_target_label(input, target_is_real)
+            loss = self.loss(input, target_label)
+
+        # loss_weight is always 1.0 for discriminators
+        return loss if is_disc else loss * self.loss_weight
+
+
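+# Usage sketch for GANLoss showing the generator/discriminator convention:
+# is_disc=True skips loss_weight and uses the discriminator-side formulation.
+# The predictions below are arbitrary stand-ins for network outputs.
+def _gan_loss_example():
+    criterion = GANLoss(gan_type='vanilla', loss_weight=0.1)
+    fake_pred = torch.randn(4, 1)
+    real_pred = torch.randn(4, 1)
+    g_loss = criterion(fake_pred, target_is_real=True, is_disc=False)
+    d_loss = 0.5 * (criterion(real_pred, True, is_disc=True) +
+                    criterion(fake_pred, False, is_disc=True))
+    return g_loss, d_loss
+
+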
+def r1_penalty(real_pred, real_img):
+    """R1 regularization for discriminator. The core idea is to
+        penalize the gradient on real data alone: when the
+        generator distribution produces the true data distribution
+        and the discriminator is equal to 0 on the data manifold, the
+        gradient penalty ensures that the discriminator cannot create
+        a non-zero gradient orthogonal to the data manifold without
+        suffering a loss in the GAN game.
+
+        Ref:
+        Eq. 9 in Which training methods for GANs do actually converge.
+        """
+    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+    return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+    path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+    """Calculate gradient penalty for wgan-gp.
+
+    Args:
+        discriminator (nn.Module): Network for the discriminator.
+        real_data (Tensor): Real input data.
+        fake_data (Tensor): Fake input data.
+        weight (Tensor): Weight tensor. Default: None.
+
+    Returns:
+        Tensor: A tensor for gradient penalty.
+    """
+
+    batch_size = real_data.size(0)
+    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+    # interpolate between real_data and fake_data
+    interpolates = alpha * real_data + (1. - alpha) * fake_data
+    interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+    disc_interpolates = discriminator(interpolates)
+    gradients = autograd.grad(
+        outputs=disc_interpolates,
+        inputs=interpolates,
+        grad_outputs=torch.ones_like(disc_interpolates),
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True)[0]
+
+    if weight is not None:
+        gradients = gradients * weight
+
+    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+    if weight is not None:
+        gradients_penalty /= torch.mean(weight)
+
+    return gradients_penalty
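+
+
+# Usage sketch for gradient_penalty_loss (WGAN-GP) with a toy discriminator;
+# the network and tensor shapes are arbitrary illustrations.
+def _gradient_penalty_example():
+    discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))
+    real_data = torch.rand(4, 3, 16, 16)
+    fake_data = torch.rand(4, 3, 16, 16)
+    return gradient_penalty_loss(discriminator, real_data, fake_data)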
diff --git a/CodeFormer/basicsr/metrics/__init__.py b/CodeFormer/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d55cc8321f124c918d78465b053aef67f13a33
--- /dev/null
+++ b/CodeFormer/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+    """Calculate metric from data and options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Metric type.
+    """
+    opt = deepcopy(opt)
+    metric_type = opt.pop('type')
+    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+    return metric
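+
+
+# Usage sketch: `data` carries the metric inputs and `opt` names the registered
+# metric plus its keyword arguments; the arrays below are arbitrary examples.
+def _calculate_metric_example():
+    import numpy as np
+    img1 = np.random.randint(0, 256, (32, 32, 3)).astype(np.float64)
+    img2 = np.random.randint(0, 256, (32, 32, 3)).astype(np.float64)
+    data = {'img1': img1, 'img2': img2}
+    opt = {'type': 'calculate_psnr', 'crop_border': 0, 'test_y_channel': False}
+    return calculate_metric(data, opt)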
diff --git a/CodeFormer/basicsr/metrics/metric_util.py b/CodeFormer/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d18f0f7816431bed6af9d58319c6435bdf5c971
--- /dev/null
+++ b/CodeFormer/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as it is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order has no
+            effect. Default: 'HWC'.
+
+    Returns:
+        ndarray: reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
+
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type), without rounding.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
diff --git a/CodeFormer/basicsr/metrics/psnr_ssim.py b/CodeFormer/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd950699c2495880236883861d9e199f900eae8
--- /dev/null
+++ b/CodeFormer/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the PSNR calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: psnr result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+    """Calculate SSIM (structural similarity).
+
+    Ref:
+    Image quality assessment: From error visibility to structural similarity
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the SSIM calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: ssim result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    ssims = []
+    for i in range(img1.shape[2]):
+        ssims.append(_ssim(img1[..., i], img2[..., i]))
+    return np.array(ssims).mean()
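+
+
+# Usage sketch for the two registered metrics; both expect images in [0, 255]
+# and, here, HWC ordering. Random arrays below are arbitrary illustrations.
+def _psnr_ssim_example():
+    img1 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
+    img2 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
+    psnr = calculate_psnr(img1, img2, crop_border=4, input_order='HWC')
+    ssim = calculate_ssim(img1, img2, crop_border=4, test_y_channel=True)
+    return psnr, ssim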
diff --git a/CodeFormer/basicsr/models/__init__.py b/CodeFormer/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bde45f003698a5b15d3517ae47b59ef1d86e0c
--- /dev/null
+++ b/CodeFormer/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+    """Build model from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            model_type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+    logger = get_root_logger()
+    logger.info(f'Model [{model.__class__.__name__}] is created.')
+    return model
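+
+
+# Usage sketch: the option values are illustrative; `model_type` must name a
+# model registered in MODEL_REGISTRY (i.e. defined in one of the `*_model.py`
+# files scanned above), and the remaining keys depend on that model:
+#     opt = {'model_type': 'SRModel', 'num_gpu': 1, 'is_train': False, ...}
+#     model = build_model(opt)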
diff --git a/CodeFormer/basicsr/ops/.DS_Store b/CodeFormer/basicsr/ops/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..6df9aa9724b867da52fa2c2fb36399b996340bc2
Binary files /dev/null and b/CodeFormer/basicsr/ops/.DS_Store differ
diff --git a/CodeFormer/basicsr/ops/__init__.py b/CodeFormer/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CodeFormer/basicsr/ops/dcn/__init__.py b/CodeFormer/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff
--- /dev/null
+++ b/CodeFormer/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+                          modulated_deform_conv)
+
+__all__ = [
+    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+    'modulated_deform_conv'
+]
diff --git a/CodeFormer/basicsr/ops/dcn/deform_conv.py b/CodeFormer/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..734154f9ed9447d585eae7df6886acb136f8a3cf
--- /dev/null
+++ b/CodeFormer/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,377 @@
+import math
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+    from . import deform_conv_ext
+except ImportError:
+    import os
+    BASICSR_JIT = os.getenv('BASICSR_JIT')
+    if BASICSR_JIT == 'True':
+        from torch.utils.cpp_extension import load
+        module_path = os.path.dirname(__file__)
+        deform_conv_ext = load(
+            'deform_conv',
+            sources=[
+                os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+                os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+                os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+            ],
+        )
+
+
+class DeformConvFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                weight,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deformable_groups=1,
+                im2col_step=64):
+        if input is not None and input.dim() != 4:
+            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.im2col_step = im2col_step
+
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+
+        if not input.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+            deform_conv_ext.deform_conv_forward(input, weight,
+                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+                                                ctx.deformable_groups, cur_im2col_step)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors
+
+        grad_input = grad_offset = grad_weight = None
+
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+                grad_input = torch.zeros_like(input)
+                grad_offset = torch.zeros_like(offset)
+                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
+                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+                                                           ctx.deformable_groups, cur_im2col_step)
+
+            if ctx.needs_input_grad[2]:
+                grad_weight = torch.zeros_like(weight)
+                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
+                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+                                                                cur_im2col_step)
+
+        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(input, weight, padding, dilation, stride):
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = padding[d]
+            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
+        return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deformable_groups=1):
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(1)  # fake tensor
+        if not input.is_cuda:
+            raise NotImplementedError
+        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+                or input.requires_grad:
+            ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+                                                       ctx.groups, ctx.deformable_groups, ctx.with_bias)
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+    @staticmethod
+    def _infer_shape(ctx, input, weight):
+        n = input.size(0)
+        channels_out = weight.size(0)
+        height, width = input.shape[2:4]
+        kernel_h, kernel_w = weight.shape[2:4]
+        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+        return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=False):
+        super(DeformConv, self).__init__()
+
+        assert not bias
+        assert in_channels % groups == 0, \
+            f'in_channels {in_channels} is not divisible by groups {groups}'
+        assert out_channels % groups == 0, \
+            f'out_channels {out_channels} is not divisible ' \
+            f'by groups {groups}'
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+
+    def forward(self, x, offset):
+        # To fix an assert error in deform_conv_cuda.cpp:128
+        # input image is smaller than kernel
+        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+        if input_pad:
+            pad_h = max(self.kernel_size[0] - x.size(2), 0)
+            pad_w = max(self.kernel_size[1] - x.size(3), 0)
+            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+                          self.deformable_groups)
+        if input_pad:
+            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+        return out
+
+
+class DeformConvPack(DeformConv):
+    """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(DeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        offset = self.conv_offset(x)
+        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+                           self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=True):
+        super(ModulatedDeformConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        self.with_bias = bias
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.init_weights()
+
+    def init_weights(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        super(ModulatedDeformConvPack, self).init_weights()
+        if hasattr(self, 'conv_offset'):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
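+
+
+# Usage sketch: the *Pack variants predict their own offsets (and mask), so they
+# drop in like a plain nn.Conv2d. The compiled CUDA extension (or BASICSR_JIT=True)
+# and a GPU are required for the forward pass; sizes below are arbitrary.
+def _modulated_dcn_pack_example():
+    dcn = ModulatedDeformConvPack(16, 16, kernel_size=3, stride=1, padding=1).cuda()
+    x = torch.rand(1, 16, 32, 32, device='cuda')
+    return dcn(x)  # shape (1, 16, 32, 32)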
diff --git a/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d9424908ed2dbd4ac3cdb98d13e09287a4d2f2d
--- /dev/null
+++ b/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor grad_im);
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const int channels, const int height,
+    const int width, const int ksize_h, const int ksize_w, const int pad_h,
+    const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im,
+    const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w, const int dilation_h,
+    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+    at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+                 int padW, int dilationH, int dilationW, int group,
+                 int deformable_group) {
+  TORCH_CHECK(weight.ndimension() == 4,
+           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+           "but got: %s",
+           weight.ndimension());
+
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+           kW);
+
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+           "kernel size should be consistent with weight, ",
+           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+           kW, weight.size(2), weight.size(3));
+
+  TORCH_CHECK(dW > 0 && dH > 0,
+           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+  TORCH_CHECK(
+      dilationW > 0 && dilationH > 0,
+      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+      dilationH, dilationW);
+
+  int ndim = input.ndimension();
+  int dimf = 0;
+  int dimh = 1;
+  int dimw = 2;
+
+  if (ndim == 4) {
+    dimf++;
+    dimh++;
+    dimw++;
+  }
+
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+           ndim);
+
+  long nInputPlane = weight.size(1) * group;
+  long inputHeight = input.size(dimh);
+  long inputWidth = input.size(dimw);
+  long nOutputPlane = weight.size(0);
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
+           "input channels must divide deformable group size");
+
+  if (outputWidth < 1 || outputHeight < 1)
+    AT_ERROR(
+        "Given input size: (%ld x %ld x %ld). "
+        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+        outputWidth);
+
+  TORCH_CHECK(input.size(1) == nInputPlane,
+           "invalid number of input planes, expected: %d, but got: %d",
+           nInputPlane, input.size(1));
+
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+           "input image is smaller than kernel");
+
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+           "invalid spatial size of offset, expected height: %d width: %d, but "
+           "got height: %d width: %d",
+           outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+           "invalid number of channels of offset");
+
+  if (gradOutput != NULL) {
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+             "invalid number of gradOutput planes, expected: %d, but got: %d",
+             nOutputPlane, gradOutput->size(dimf));
+
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+              gradOutput->size(dimw) == outputWidth),
+             "invalid size of gradOutput, expected height: %d width: %d , but "
+             "got height: %d width: %d",
+             outputHeight, outputWidth, gradOutput->size(dimh),
+             gradOutput->size(dimw));
+  }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step) {
+  // todo: resize columns to include im2col: done
+  // todo: add im2col_step as input
+  // todo: add new output buffer and transpose it to output (or directly
+  // transpose output) todo: possibly change data indexing because of
+  // parallel_imgs
+
+  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+  at::DeviceGuard guard(input.device());
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input.unsqueeze_(0);
+    offset.unsqueeze_(0);
+  }
+
+  // todo: assert batchsize dividable by im2col_step
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+                        outputHeight, outputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+    ones = at::ones({outputHeight, outputWidth}, input.options());
+  }
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  at::Tensor output_buffer =
+      at::zeros({batchSize / im2col_step, nOutputPlane,
+                 im2col_step * outputHeight, outputWidth},
+                output.options());
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), group, output_buffer.size(1) / group,
+       output_buffer.size(2), output_buffer.size(3)});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      output_buffer[elt][g] = output_buffer[elt][g]
+                                  .flatten(1)
+                                  .addmm_(weight[g].flatten(1), columns[g])
+                                  .view_as(output_buffer[elt][g]);
+    }
+  }
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+       output_buffer.size(3), output_buffer.size(4)});
+
+  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+                                      im2col_step, outputHeight, outputWidth});
+  output_buffer.transpose_(1, 2);
+  output.copy_(output_buffer);
+  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    output = output.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step) {
+  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+  at::DeviceGuard guard(input.device());
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view({1, input.size(0), input.size(1), input.size(2)});
+    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  // change order of grad output
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                              inputHeight, inputWidth});
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+                                deformable_group * 2 * kH * kW, outputHeight,
+                                outputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    // divide into groups
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), group, gradOutput.size(1) / group,
+         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+                            dilationH, dilationW, im2col_step, deformable_group,
+                            gradOffset[elt]);
+
+    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, gradInput[elt]);
+  }
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  gradOffset = gradOffset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+    gradOffset =
+        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step) {
+  // todo: transpose and reshape outGrad
+  // todo: reshape columns
+  // todo: add im2col_step as input
+
+  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+              padW, dilationH, dilationW, group, deformable_group);
+  at::DeviceGuard guard(input.device());
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view(
+        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = gradWeight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+                             outputHeight, outputWidth});
+  gradOutputBuffer.copy_(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+                             im2col_step * outputHeight, outputWidth});
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    // divide into group
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    gradWeight =
+        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+                         gradWeight.size(2), gradWeight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      gradWeight[g] = gradWeight[g]
+                          .flatten(1)
+                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
+                                  columns[g].transpose(1, 0), 1.0, scale)
+                          .view_as(gradWeight[g]);
+    }
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0),
+         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+                                  gradWeight.size(2), gradWeight.size(3),
+                                  gradWeight.size(4)});
+  }
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+  }
+
+  return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias) {
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_out = weight.size(0);
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
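+  // Illustration with hypothetical values: a 64x64 input, kernel 3, pad 1,
+  // stride 1, dilation 1 gives (64 + 2*1 - (1*(3-1)+1)) / 1 + 1 = 64, i.e. a
+  // "same"-sized output; the same formula is used throughout these files.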
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  // resize output
+  output = output.view({batch, channels_out, height_out, width_out}).zero_();
+  // resize temporary columns
+  columns =
+      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+                input.options());
+
+  output = output.view({output.size(0), group, output.size(1) / group,
+                        output.size(2), output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    // divide into group
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+    for (int g = 0; g < group; g++) {
+      output[b][g] = output[b][g]
+                         .flatten(1)
+                         .addmm_(weight[g].flatten(1), columns[g])
+                         .view_as(output[b][g]);
+    }
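+    // Per group, the addmm_ above computes (roughly)
+    //   output[b][g] += weight[g].flatten(1) @ columns[g],
+    // i.e. the usual im2col + GEMM formulation of convolution, with the
+    // deformable/modulated sampling already folded into `columns`.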
+
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+  }
+
+  output = output.view({output.size(0), output.size(1) * output.size(2),
+                        output.size(3), output.size(4)});
+
+  if (with_bias) {
+    output += bias.view({1, bias.size(0), 1, 1});
+  }
+}
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias) {
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  grad_input = grad_input.view({batch, channels, height, width});
+  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+                      input.options());
+
+  grad_output =
+      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+                        grad_output.size(2), grad_output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    // divide into group
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+
+    // gradient w.r.t. input coordinate data
+    modulated_deformable_col2im_coord_cuda(
+        columns, input[b], offset[b], mask[b], 1, channels, height, width,
+        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+        grad_mask[b]);
+    // gradient w.r.t. input data
+    modulated_deformable_col2im_cuda(
+        columns, offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+    // gradient w.r.t. weight, dWeight should accumulate across the batch and
+    // group
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+                                    grad_weight.size(1), grad_weight.size(2),
+                                    grad_weight.size(3)});
+    if (with_bias)
+      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+    for (int g = 0; g < group; g++) {
+      grad_weight[g] =
+          grad_weight[g]
+              .flatten(1)
+              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+              .view_as(grad_weight[g]);
+      if (with_bias) {
+        grad_bias[g] =
+            grad_bias[g]
+                .view({-1, 1})
+                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+                .view(-1);
+      }
+    }
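+    // In effect: grad_weight[g] accumulates grad_output[b][g].flatten(1) @ columns[g]^T,
+    // and grad_bias[g] accumulates grad_output[b][g].flatten(1) @ ones.view(-1, 1),
+    // i.e. the output gradient summed over the height_out * width_out positions.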
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+                                    grad_weight.size(2), grad_weight.size(3),
+                                    grad_weight.size(4)});
+    if (with_bias)
+      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+  }
+  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+                                  grad_output.size(2), grad_output.size(3),
+                                  grad_output.size(4)});
+}
diff --git a/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34
--- /dev/null
+++ b/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
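+
+// Note: CUDA_KERNEL_LOOP is a grid-stride loop, so capping the grid at
+// kMaxGridNum is safe; each thread then handles several elements.
+// Illustration with hypothetical values: N = 2,000,000 gives
+// (N + 1023) / 1024 = 1954 blocks, well under the 65535 cap.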
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                               const int height, const int width, scalar_t h, scalar_t w)
+{
+
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
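+
+// Illustration with hypothetical values: sampling at (h, w) = (1.25, 2.5) gives
+// h_low = 1, w_low = 2, lh = 0.25, lw = 0.5, so the corner weights are
+// w1 = 0.375, w2 = 0.375, w3 = 0.125, w4 = 0.125 (summing to 1); out-of-range
+// corners contribute 0 through the v1..v4 guards above.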
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                        const int h, const int w, const int height, const int width)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                          const int height, const int width, const scalar_t *im_data,
+                                          const int data_width, const int bp_dir)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+                                             const int height, const int width, const int kernel_h, const int kernel_w,
+                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+                                             const int batch_size, const int num_channels, const int deformable_group,
+                                             const int height_col, const int width_col,
+                                             scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index: linear index into the output column matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const scalar_t map_h = i * dilation_h + offset_h;
+          //const scalar_t map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val;
+        data_col_ptr += batch_size * height_col * width_col;
+      }
+    }
+  }
+}
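+
+// As written, data_col is laid out roughly as
+// [channels * kernel_h * kernel_w, batch_size, height_col, width_col]:
+// each kernel tap (i, j) fills one row block, which is why data_col_ptr is
+// advanced by batch_size * height_col * width_col per tap.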
+
+void deformable_im2col(
+    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h, const int ksize_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  // todo: check parallel_imgs is correctly passed in
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+            height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+    const int n, const scalar_t *data_col, const scalar_t *data_offset,
+    const int channels, const int height, const int width,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int channel_per_deformable_group,
+    const int batch_size, const int deformable_group,
+    const int height_col, const int width_col,
+    scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+                                                        2 * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index];
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+void deformable_col2im(
+    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group,
+    at::Tensor grad_im)
+{
+
+  // todo: make sure parallel_imgs is passed in correctly
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+            ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+                                                   const scalar_t *data_im, const scalar_t *data_offset,
+                                                   const int channels, const int height, const int width,
+                                                   const int kernel_h, const int kernel_w,
+                                                   const int pad_h, const int pad_w,
+                                                   const int stride_h, const int stride_w,
+                                                   const int dilation_h, const int dilation_w,
+                                                   const int channel_per_deformable_group,
+                                                   const int batch_size, const int offset_channels, const int deformable_group,
+                                                   const int height_col, const int width_col, scalar_t *grad_offset)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    scalar_t val = 0;
+    int w = index % width_col;
+    int h = (index / width_col) % height_col;
+    int c = (index / width_col / height_col) % offset_channels;
+    int b = (index / width_col / height_col) / offset_channels;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+    const int col_step = kernel_h * kernel_w;
+    int cnt = 0;
+    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+                                                  batch_size * width_col * height_col;
+    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+                                                        kernel_h * kernel_w * height_col * width_col;
+
+    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+    {
+      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+      const int bp_dir = offset_c % 2;
+
+      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+      int w_out = col_pos % width_col;
+      int h_out = (col_pos / width_col) % height_col;
+      int w_in = w_out * stride_w - pad_w;
+      int h_in = h_out * stride_h - pad_h;
+      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+      {
+        inv_h = inv_w = -2;
+      }
+      const scalar_t weight = get_coordinate_weight(
+          inv_h, inv_w,
+          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+      val += weight * data_col_ptr[col_pos];
+      cnt += 1;
+    }
+
+    grad_offset[index] = val;
+  }
+}
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+    const int channels, const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+            height_col, width_col, grad_offset_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                         const int height, const int width, scalar_t h, scalar_t w)
+{
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                             const int h, const int w, const int height, const int width)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                               const int height, const int width, const scalar_t *im_data,
+                                               const int data_width, const int bp_dir)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                       const int height, const int width, const int kernel_h, const int kernel_w,
+                                                       const int pad_h, const int pad_w,
+                                                       const int stride_h, const int stride_w,
+                                                       const int dilation_h, const int dilation_w,
+                                                       const int channel_per_deformable_group,
+                                                       const int batch_size, const int num_channels, const int deformable_group,
+                                                       const int height_col, const int width_col,
+                                                       scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index: linear index into the output column matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const float map_h = i * dilation_h + offset_h;
+          //const float map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val * mask;
+        data_col_ptr += batch_size * height_col * width_col;
+        //data_col_ptr += height_col * width_col;
+      }
+    }
+  }
+}
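+
+// The only difference from the plain deformable im2col kernel above is the
+// per-tap modulation: the bilinearly sampled value is scaled by `mask` before
+// being written to the column buffer (the "modulated" part of DCNv2).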
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+                                                       const int channels, const int height, const int width,
+                                                       const int kernel_h, const int kernel_w,
+                                                       const int pad_h, const int pad_w,
+                                                       const int stride_h, const int stride_w,
+                                                       const int dilation_h, const int dilation_w,
+                                                       const int channel_per_deformable_group,
+                                                       const int batch_size, const int deformable_group,
+                                                       const int height_col, const int width_col,
+                                                       scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index] * mask;
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+                                                             const scalar_t *data_col, const scalar_t *data_im,
+                                                             const scalar_t *data_offset, const scalar_t *data_mask,
+                                                             const int channels, const int height, const int width,
+                                                             const int kernel_h, const int kernel_w,
+                                                             const int pad_h, const int pad_w,
+                                                             const int stride_h, const int stride_w,
+                                                             const int dilation_h, const int dilation_w,
+                                                             const int channel_per_deformable_group,
+                                                             const int batch_size, const int offset_channels, const int deformable_group,
+                                                             const int height_col, const int width_col,
+                                                             scalar_t *grad_offset, scalar_t *grad_mask)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    scalar_t val = 0, mval = 0;
+    int w = index % width_col;
+    int h = (index / width_col) % height_col;
+    int c = (index / width_col / height_col) % offset_channels;
+    int b = (index / width_col / height_col) / offset_channels;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+    const int col_step = kernel_h * kernel_w;
+    int cnt = 0;
+    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+    {
+      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+      const int bp_dir = offset_c % 2;
+
+      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+      int w_out = col_pos % width_col;
+      int h_out = (col_pos / width_col) % height_col;
+      int w_in = w_out * stride_w - pad_w;
+      int h_in = h_out * stride_h - pad_h;
+      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+      {
+        inv_h = inv_w = -2;
+      }
+      else
+      {
+        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+      }
+      const scalar_t weight = dmcn_get_coordinate_weight(
+          inv_h, inv_w,
+          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+      val += weight * data_col_ptr[col_pos] * mask;
+      cnt += 1;
+    }
+    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+    grad_offset[index] = val;
+    if (offset_c % 2 == 0)
+      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+  }
+}
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * batch_size * height_col * width_col;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, channels, deformable_group, height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, at::Tensor grad_im)
+{
+
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group,
+    at::Tensor grad_offset, at::Tensor grad_mask)
+{
+  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+            grad_offset_, grad_mask_);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
diff --git a/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp b/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042
--- /dev/null
+++ b/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modified from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+#define WITH_CUDA  // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias);
+#endif
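+
+// Each wrapper below only dispatches: it checks input.device().is_cuda() and
+// forwards to the matching *_cuda implementation declared above; non-CUDA
+// tensors fall through to AT_ERROR, since no CPU fallback is compiled here.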
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_forward_cuda(input, weight, offset, output, columns,
+        ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+        deformable_group, im2col_step);
+#else
+    AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_backward_input_cuda(input, offset, gradOutput,
+        gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+        dilationW, dilationH, group, deformable_group, im2col_step);
+#else
+    AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_parameters(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+        gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+        dilationH, group, deformable_group, scale, im2col_step);
+#else
+    AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+        offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+        stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+        deformable_group, with_bias);
+#else
+    AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias) {
+  if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+        offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+        grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+        pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+        with_bias);
+#else
+    AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("deform_conv_forward", &deform_conv_forward,
+        "deform forward");
+  m.def("deform_conv_backward_input", &deform_conv_backward_input,
+        "deform_conv_backward_input");
+  m.def("deform_conv_backward_parameters",
+        &deform_conv_backward_parameters,
+        "deform_conv_backward_parameters");
+  m.def("modulated_deform_conv_forward",
+        &modulated_deform_conv_forward,
+        "modulated deform conv forward");
+  m.def("modulated_deform_conv_backward",
+        &modulated_deform_conv_backward,
+        "modulated deform conv backward");
+}
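
The pybind11 module above is normally compiled ahead of time as `deform_conv_ext` (see the `--cuda_ext` branch in `basicsr/setup.py` later in this patch). As a rough, non-authoritative sketch, the same sources can also be JIT-compiled the way the other ops in this repo do it; the extension name and source list below simply mirror `setup.py` and are illustrative, not a verbatim copy of the repo's wrapper.

# Hedged sketch: JIT-compile the deform_conv extension, mirroring the
# BASICSR_JIT fallback used by fused_act.py / upfirdn2d.py below.
import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)  # assumed to resolve to .../basicsr/ops/dcn
deform_conv_ext = load(
    'deform_conv',
    sources=[
        os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
    ],
)
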
diff --git a/CodeFormer/basicsr/ops/fused_act/__init__.py b/CodeFormer/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422
--- /dev/null
+++ b/CodeFormer/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/CodeFormer/basicsr/ops/fused_act/fused_act.py b/CodeFormer/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..588f815e596ab0fc83ab0f9d21426c22ec5ed7c3
--- /dev/null
+++ b/CodeFormer/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,89 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+    from . import fused_act_ext
+except ImportError:
+    import os
+    BASICSR_JIT = os.getenv('BASICSR_JIT')
+    if BASICSR_JIT == 'True':
+        from torch.utils.cpp_extension import load
+        module_path = os.path.dirname(__file__)
+        fused_act_ext = load(
+            'fused',
+            sources=[
+                os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+                os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+            ],
+        )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+                                                    ctx.scale)
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
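
A minimal usage sketch of the two entry points exported above, assuming the compiled `fused_act_ext` extension is available and a CUDA device is present (the kernel has no CPU path); shapes and values are illustrative.

import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

# bias-add + leaky ReLU + rescale, fused into a single CUDA kernel
x = torch.randn(4, 64, 32, 32, device='cuda')

act = FusedLeakyReLU(channel=64).cuda()   # module form with a learnable per-channel bias
y_module = act(x)

bias = torch.zeros(64, device='cuda')
y_func = fused_leaky_relu(x, bias, negative_slope=0.2, scale=2**0.5)  # functional form

assert y_module.shape == y_func.shape == x.shape
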
diff --git a/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp b/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee
--- /dev/null
+++ b/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+                                const torch::Tensor& bias,
+                                const torch::Tensor& refer,
+                                int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+                             const torch::Tensor& bias,
+                             const torch::Tensor& refer,
+                             int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0
--- /dev/null
+++ b/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ? 1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
diff --git a/CodeFormer/basicsr/ops/upfirdn2d/__init__.py b/CodeFormer/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd
--- /dev/null
+++ b/CodeFormer/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0
--- /dev/null
+++ b/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
@@ -0,0 +1,24 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+                            int up_x, int up_y, int down_x, int down_y,
+                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+                        int up_x, int up_y, int down_x, int down_y,
+                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
diff --git a/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1
--- /dev/null
+++ b/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+  int c = a / b;
+
+  if (c * b > a) {
+    c--;
+  }
+
+  return c;
+}
+
+struct UpFirDn2DKernelParams {
+  int up_x;
+  int up_y;
+  int down_x;
+  int down_y;
+  int pad_x0;
+  int pad_x1;
+  int pad_y0;
+  int pad_y1;
+
+  int major_dim;
+  int in_h;
+  int in_w;
+  int minor_dim;
+  int kernel_h;
+  int kernel_w;
+  int out_h;
+  int out_w;
+  int loop_major;
+  int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+                                       const scalar_t *kernel,
+                                       const UpFirDn2DKernelParams p) {
+  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int out_y = minor_idx / p.minor_dim;
+  minor_idx -= out_y * p.minor_dim;
+  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+  int major_idx_base = blockIdx.z * p.loop_major;
+
+  if (out_x_base >= p.out_w || out_y >= p.out_h ||
+      major_idx_base >= p.major_dim) {
+    return;
+  }
+
+  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+  for (int loop_major = 0, major_idx = major_idx_base;
+       loop_major < p.loop_major && major_idx < p.major_dim;
+       loop_major++, major_idx++) {
+    for (int loop_x = 0, out_x = out_x_base;
+         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+      const scalar_t *x_p =
+          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+                 minor_idx];
+      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+      int x_px = p.minor_dim;
+      int k_px = -p.up_x;
+      int x_py = p.in_w * p.minor_dim;
+      int k_py = -p.up_y * p.kernel_w;
+
+      scalar_t v = 0.0f;
+
+      for (int y = 0; y < h; y++) {
+        for (int x = 0; x < w; x++) {
+          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+          x_p += x_px;
+          k_p += k_px;
+        }
+
+        x_p += x_py - w * x_px;
+        k_p += k_py - w * k_px;
+      }
+
+      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+          minor_idx] = v;
+    }
+  }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+                                 const scalar_t *kernel,
+                                 const UpFirDn2DKernelParams p) {
+  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+  __shared__ volatile float sk[kernel_h][kernel_w];
+  __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+  int minor_idx = blockIdx.x;
+  int tile_out_y = minor_idx / p.minor_dim;
+  minor_idx -= tile_out_y * p.minor_dim;
+  tile_out_y *= tile_out_h;
+  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+  int major_idx_base = blockIdx.z * p.loop_major;
+
+  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+      major_idx_base >= p.major_dim) {
+    return;
+  }
+
+  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+       tap_idx += blockDim.x) {
+    int ky = tap_idx / kernel_w;
+    int kx = tap_idx - ky * kernel_w;
+    scalar_t v = 0.0;
+
+    if (kx < p.kernel_w & ky < p.kernel_h) {
+      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+    }
+
+    sk[ky][kx] = v;
+  }
+
+  for (int loop_major = 0, major_idx = major_idx_base;
+       loop_major < p.loop_major & major_idx < p.major_dim;
+       loop_major++, major_idx++) {
+    for (int loop_x = 0, tile_out_x = tile_out_x_base;
+         loop_x < p.loop_x & tile_out_x < p.out_w;
+         loop_x++, tile_out_x += tile_out_w) {
+      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+      int tile_in_x = floor_div(tile_mid_x, up_x);
+      int tile_in_y = floor_div(tile_mid_y, up_y);
+
+      __syncthreads();
+
+      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+           in_idx += blockDim.x) {
+        int rel_in_y = in_idx / tile_in_w;
+        int rel_in_x = in_idx - rel_in_y * tile_in_w;
+        int in_x = rel_in_x + tile_in_x;
+        int in_y = rel_in_y + tile_in_y;
+
+        scalar_t v = 0.0;
+
+        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+                        p.minor_dim +
+                    minor_idx];
+        }
+
+        sx[rel_in_y][rel_in_x] = v;
+      }
+
+      __syncthreads();
+      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+           out_idx += blockDim.x) {
+        int rel_out_y = out_idx / tile_out_w;
+        int rel_out_x = out_idx - rel_out_y * tile_out_w;
+        int out_x = rel_out_x + tile_out_x;
+        int out_y = rel_out_y + tile_out_y;
+
+        int mid_x = tile_mid_x + rel_out_x * down_x;
+        int mid_y = tile_mid_y + rel_out_y * down_y;
+        int in_x = floor_div(mid_x, up_x);
+        int in_y = floor_div(mid_y, up_y);
+        int rel_in_x = in_x - tile_in_x;
+        int rel_in_y = in_y - tile_in_y;
+        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+        int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+        scalar_t v = 0.0;
+
+#pragma unroll
+        for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+          for (int x = 0; x < kernel_w / up_x; x++)
+            v += sx[rel_in_y + y][rel_in_x + x] *
+                 sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+        if (out_x < p.out_w & out_y < p.out_h) {
+          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+              minor_idx] = v;
+        }
+      }
+    }
+  }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+                           const torch::Tensor &kernel, int up_x, int up_y,
+                           int down_x, int down_y, int pad_x0, int pad_x1,
+                           int pad_y0, int pad_y1) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+  UpFirDn2DKernelParams p;
+
+  auto x = input.contiguous();
+  auto k = kernel.contiguous();
+
+  p.major_dim = x.size(0);
+  p.in_h = x.size(1);
+  p.in_w = x.size(2);
+  p.minor_dim = x.size(3);
+  p.kernel_h = k.size(0);
+  p.kernel_w = k.size(1);
+  p.up_x = up_x;
+  p.up_y = up_y;
+  p.down_x = down_x;
+  p.down_y = down_y;
+  p.pad_x0 = pad_x0;
+  p.pad_x1 = pad_x1;
+  p.pad_y0 = pad_y0;
+  p.pad_y1 = pad_y1;
+
+  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+            p.down_y;
+  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+            p.down_x;
+
+  auto out =
+      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+  int mode = -1;
+
+  int tile_out_h = -1;
+  int tile_out_w = -1;
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 1;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 3 && p.kernel_w <= 3) {
+    mode = 2;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 3;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 4;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 5;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 6;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  dim3 block_size;
+  dim3 grid_size;
+
+  if (tile_out_h > 0 && tile_out_w > 0) {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 1;
+    block_size = dim3(32 * 8, 1, 1);
+    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  } else {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 4;
+    block_size = dim3(4, 32, 1);
+    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+    switch (mode) {
+    case 1:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 2:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 3:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 4:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 5:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 6:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    default:
+      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+          k.data_ptr<scalar_t>(), p);
+    }
+  });
+
+  return out;
+}
diff --git a/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py b/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..667f96e1ded35d48f163f37e21d1ed8ff191aac3
--- /dev/null
+++ b/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,186 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+    from . import upfirdn2d_ext
+except ImportError:
+    import os
+    BASICSR_JIT = os.getenv('BASICSR_JIT')
+    if BASICSR_JIT == 'True':
+        from torch.utils.cpp_extension import load
+        module_path = os.path.dirname(__file__)
+        upfirdn2d_ext = load(
+            'upfirdn2d',
+            sources=[
+                os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+                os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+            ],
+        )
+
+
+class UpFirDn2dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_ext.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            ctx.up_x,
+            ctx.up_y,
+            ctx.down_x,
+            ctx.down_y,
+            ctx.pad_x0,
+            ctx.pad_x1,
+            ctx.pad_y0,
+            ctx.pad_y1,
+        )
+        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+        #                                  ctx.out_size[1], ctx.in_size[3])
+        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+        return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        batch, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+        # out = out.view(major, out_h, out_w, minor)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    if input.device.type == 'cpu':
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+    else:
+        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+    return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+    out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
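
A small usage sketch of `upfirdn2d`, padded for 2x upsampling with a 4-tap smoothing filter as in StyleGAN2. On CPU the call goes through `upfirdn2d_native`, so no compiled extension is required; the kernel values and padding below are illustrative.

import torch
from basicsr.ops.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 16, 16)          # NCHW input
k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]       # separable 4x4 smoothing kernel
kernel = kernel / kernel.sum() * 4     # gain of up**2 when upsampling

# up=2, down=1, pad=(2, 1): zero-insertion upsample, filter, no downsample
out = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))
print(out.shape)  # torch.Size([1, 3, 32, 32])
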
diff --git a/CodeFormer/basicsr/setup.py b/CodeFormer/basicsr/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..382a2aa1006e581eaf31dbb3155d4b0ba3b31140
--- /dev/null
+++ b/CodeFormer/basicsr/setup.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+import os
+import subprocess
+import sys
+import time
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+
+version_file = './basicsr/version.py'
+
+
+def readme():
+    with open('README.md', encoding='utf-8') as f:
+        content = f.read()
+    return content
+
+
+def get_git_hash():
+
+    def _minimal_ext_cmd(cmd):
+        # construct minimal environment
+        env = {}
+        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+            v = os.environ.get(k)
+            if v is not None:
+                env[k] = v
+        # LANGUAGE is used on win32
+        env['LANGUAGE'] = 'C'
+        env['LANG'] = 'C'
+        env['LC_ALL'] = 'C'
+        out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+        return out
+
+    try:
+        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+        sha = out.strip().decode('ascii')
+    except OSError:
+        sha = 'unknown'
+
+    return sha
+
+
+def get_hash():
+    if os.path.exists('.git'):
+        sha = get_git_hash()[:7]
+    elif os.path.exists(version_file):
+        try:
+            from version import __version__
+            sha = __version__.split('+')[-1]
+        except ImportError:
+            raise ImportError('Unable to get git version')
+    else:
+        sha = 'unknown'
+
+    return sha
+
+
+def write_version_py():
+    content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+    sha = get_hash()
+    with open('./basicsr/VERSION', 'r') as f:
+        SHORT_VERSION = f.read().strip()
+    VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+    with open(version_file, 'w') as f:
+        f.write(version_file_str)
+
+
+def get_version():
+    with open(version_file, 'r') as f:
+        exec(compile(f.read(), version_file, 'exec'))
+    return locals()['__version__']
+
+
+def make_cuda_ext(name, module, sources, sources_cuda=None):
+    if sources_cuda is None:
+        sources_cuda = []
+    define_macros = []
+    extra_compile_args = {'cxx': []}
+
+    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+        define_macros += [('WITH_CUDA', None)]
+        extension = CUDAExtension
+        extra_compile_args['nvcc'] = [
+            '-D__CUDA_NO_HALF_OPERATORS__',
+            '-D__CUDA_NO_HALF_CONVERSIONS__',
+            '-D__CUDA_NO_HALF2_OPERATORS__',
+        ]
+        sources += sources_cuda
+    else:
+        print(f'Compiling {name} without CUDA')
+        extension = CppExtension
+
+    return extension(
+        name=f'{module}.{name}',
+        sources=[os.path.join(*module.split('.'), p) for p in sources],
+        define_macros=define_macros,
+        extra_compile_args=extra_compile_args)
+
+
+def get_requirements(filename='requirements.txt'):
+    with open(os.path.join('.', filename), 'r') as f:
+        requires = [line.replace('\n', '') for line in f.readlines()]
+    return requires
+
+
+if __name__ == '__main__':
+    if '--cuda_ext' in sys.argv:
+        ext_modules = [
+            make_cuda_ext(
+                name='deform_conv_ext',
+                module='ops.dcn',
+                sources=['src/deform_conv_ext.cpp'],
+                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
+            make_cuda_ext(
+                name='fused_act_ext',
+                module='ops.fused_act',
+                sources=['src/fused_bias_act.cpp'],
+                sources_cuda=['src/fused_bias_act_kernel.cu']),
+            make_cuda_ext(
+                name='upfirdn2d_ext',
+                module='ops.upfirdn2d',
+                sources=['src/upfirdn2d.cpp'],
+                sources_cuda=['src/upfirdn2d_kernel.cu']),
+        ]
+        sys.argv.remove('--cuda_ext')
+    else:
+        ext_modules = []
+
+    write_version_py()
+    setup(
+        name='basicsr',
+        version=get_version(),
+        description='Open Source Image and Video Super-Resolution Toolbox',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
+        author='Xintao Wang',
+        author_email='xintao.wang@outlook.com',
+        keywords='computer vision, restoration, super resolution',
+        url='https://github.com/xinntao/BasicSR',
+        include_package_data=True,
+        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+        classifiers=[
+            'Development Status :: 4 - Beta',
+            'License :: OSI Approved :: Apache Software License',
+            'Operating System :: OS Independent',
+            'Programming Language :: Python :: 3',
+            'Programming Language :: Python :: 3.7',
+            'Programming Language :: Python :: 3.8',
+        ],
+        license='Apache License 2.0',
+        setup_requires=['cython', 'numpy'],
+        install_requires=get_requirements(),
+        ext_modules=ext_modules,
+        cmdclass={'build_ext': BuildExtension},
+        zip_safe=False)
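
There are two ways to obtain the compiled ops declared above, sketched here under the assumption that the repo layout matches the module paths in `make_cuda_ext`: build them ahead of time by running this script with the extra `--cuda_ext` flag, or let the ops JIT-compile at import time via the `BASICSR_JIT` environment variable checked in `fused_act.py` and `upfirdn2d.py`.

# Hedged sketch: opt into the JIT fallback instead of a pre-built extension.
# BASICSR_JIT must be set before the ops modules are imported for the first time.
import os
os.environ['BASICSR_JIT'] = 'True'

from basicsr.ops.fused_act import fused_leaky_relu   # triggers cpp_extension.load on ImportError
from basicsr.ops.upfirdn2d import upfirdn2d

An ahead-of-time build would instead run this script with the flag, e.g. `python basicsr/setup.py develop --cuda_ext` (invocation assumed from the `sys.argv` check above).
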
diff --git a/CodeFormer/basicsr/train.py b/CodeFormer/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a01c0dfccdb8b02283100ec5b792c33afaf22f5e
--- /dev/null
+++ b/CodeFormer/basicsr/train.py
@@ -0,0 +1,225 @@
+import argparse
+import datetime
+import logging
+import math
+import copy
+import random
+import time
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.models import build_model
+from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
+                           init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
+from basicsr.utils.dist_util import get_dist_info, init_dist
+from basicsr.utils.options import dict2str, parse
+
+import warnings
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+def parse_options(root_path, is_train=True):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    opt = parse(args.opt, root_path, is_train=is_train)
+
+    # distributed settings
+    if args.launcher == 'none':
+        opt['dist'] = False
+        print('Distributed training disabled.', flush=True)
+    else:
+        opt['dist'] = True
+        if args.launcher == 'slurm' and 'dist_params' in opt:
+            init_dist(args.launcher, **opt['dist_params'])
+        else:
+            init_dist(args.launcher)
+
+    opt['rank'], opt['world_size'] = get_dist_info()
+
+    # random seed
+    seed = opt.get('manual_seed')
+    if seed is None:
+        seed = random.randint(1, 10000)
+        opt['manual_seed'] = seed
+    set_random_seed(seed + opt['rank'])
+
+    return opt
+
+
+def init_loggers(opt):
+    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+    logger.info(get_env_info())
+    logger.info(dict2str(opt))
+
+    # initialize wandb logger before tensorboard logger to allow proper sync:
+    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
+        assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+        init_wandb_logger(opt)
+    tb_logger = None
+    if opt['logger'].get('use_tb_logger'):
+        tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
+    return logger, tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+    # create train and val dataloaders
+    train_loader, val_loader = None, None
+    for phase, dataset_opt in opt['datasets'].items():
+        if phase == 'train':
+            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+            train_set = build_dataset(dataset_opt)
+            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+            train_loader = build_dataloader(
+                train_set,
+                dataset_opt,
+                num_gpu=opt['num_gpu'],
+                dist=opt['dist'],
+                sampler=train_sampler,
+                seed=opt['manual_seed'])
+
+            num_iter_per_epoch = math.ceil(
+                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+            total_iters = int(opt['train']['total_iter'])
+            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+            logger.info('Training statistics:'
+                        f'\n\tNumber of train images: {len(train_set)}'
+                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
+                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+
+        elif phase == 'val':
+            val_set = build_dataset(dataset_opt)
+            val_loader = build_dataloader(
+                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
+        else:
+            raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+    return train_loader, train_sampler, val_loader, total_epochs, total_iters
+
+
+def train_pipeline(root_path):
+    # parse options, set distributed setting, set random seed
+    opt = parse_options(root_path, is_train=True)
+
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.deterministic = True
+
+    # load resume states if necessary
+    if opt['path'].get('resume_state'):
+        device_id = torch.cuda.current_device()
+        resume_state = torch.load(
+            opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
+    else:
+        resume_state = None
+
+    # mkdir for experiments and logger
+    if resume_state is None:
+        make_exp_dirs(opt)
+        if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
+            mkdir_and_rename(osp.join('tb_logger', opt['name']))
+
+    # initialize loggers
+    logger, tb_logger = init_loggers(opt)
+
+    # create train and validation dataloaders
+    result = create_train_val_dataloader(opt, logger)
+    train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+
+    # create model
+    if resume_state:  # resume training
+        check_resume(opt, resume_state['iter'])
+        model = build_model(opt)
+        model.resume_training(resume_state)  # handle optimizers and schedulers
+        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+        start_epoch = resume_state['epoch']
+        current_iter = resume_state['iter']
+    else:
+        model = build_model(opt)
+        start_epoch = 0
+        current_iter = 0
+
+    # create message logger (formatted outputs)
+    msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+    # dataloader prefetcher
+    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+    if prefetch_mode is None or prefetch_mode == 'cpu':
+        prefetcher = CPUPrefetcher(train_loader)
+    elif prefetch_mode == 'cuda':
+        prefetcher = CUDAPrefetcher(train_loader, opt)
+        logger.info(f'Use {prefetch_mode} prefetch dataloader')
+        if opt['datasets']['train'].get('pin_memory') is not True:
+            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+    else:
+        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
+
+    # training
+    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
+    data_time, iter_time = time.time(), time.time()
+    start_time = time.time()
+
+    for epoch in range(start_epoch, total_epochs + 1):
+        train_sampler.set_epoch(epoch)
+        prefetcher.reset()
+        train_data = prefetcher.next()
+
+        while train_data is not None:
+            data_time = time.time() - data_time
+
+            current_iter += 1
+            if current_iter > total_iters:
+                break
+            # update learning rate
+            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+            # training
+            model.feed_data(train_data)
+            model.optimize_parameters(current_iter)
+            iter_time = time.time() - iter_time
+            # log
+            if current_iter % opt['logger']['print_freq'] == 0:
+                log_vars = {'epoch': epoch, 'iter': current_iter}
+                log_vars.update({'lrs': model.get_current_learning_rate()})
+                log_vars.update({'time': iter_time, 'data_time': data_time})
+                log_vars.update(model.get_current_log())
+                msg_logger(log_vars)
+
+            # save models and training states
+            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+                logger.info('Saving models and training states.')
+                model.save(epoch, current_iter)
+
+            # validation
+            if opt.get('val') is not None and opt['datasets'].get('val') is not None \
+                and (current_iter % opt['val']['val_freq'] == 0):
+                model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+
+            data_time = time.time()
+            iter_time = time.time()
+            train_data = prefetcher.next()
+        # end of iter
+
+    # end of epoch
+
+    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+    logger.info(f'End of training. Time consumed: {consumed_time}')
+    logger.info('Save the latest model.')
+    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
+    if opt.get('val') is not None and opt['datasets'].get('val'):
+        model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+    if tb_logger:
+        tb_logger.close()
+
+
+if __name__ == '__main__':
+    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+    train_pipeline(root_path)
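
`train_pipeline` reads its options from the `-opt` YAML passed on the command line. A hedged, single-process sketch of driving it programmatically is below; the option file path is a placeholder, not a file shipped in this patch.

# Hedged sketch: run the training pipeline without torch.distributed by
# faking the command line that parse_options() expects.
import sys
from os import path as osp
from basicsr.train import train_pipeline

sys.argv = ['basicsr/train.py', '-opt', 'options/my_train_config.yml', '--launcher', 'none']
root_path = osp.abspath('.')  # assumed to be the CodeFormer project root
train_pipeline(root_path)
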
diff --git a/CodeFormer/basicsr/utils/__init__.py b/CodeFormer/basicsr/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcc1d540462712387523d1e326d1dfc2bcfbf32
--- /dev/null
+++ b/CodeFormer/basicsr/utils/__init__.py
@@ -0,0 +1,29 @@
+from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+    # file_client.py
+    'FileClient',
+    # img_util.py
+    'img2tensor',
+    'tensor2img',
+    'imfrombytes',
+    'imwrite',
+    'crop_border',
+    # logger.py
+    'MessageLogger',
+    'init_tb_logger',
+    'init_wandb_logger',
+    'get_root_logger',
+    'get_env_info',
+    # misc.py
+    'set_random_seed',
+    'get_time_str',
+    'mkdir_and_rename',
+    'make_exp_dirs',
+    'scandir',
+    'check_resume',
+    'sizeof_fmt'
+]
diff --git a/CodeFormer/basicsr/utils/dist_util.py b/CodeFormer/basicsr/utils/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0
--- /dev/null
+++ b/CodeFormer/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+    """Initialize slurm distributed training environment.
+
+    If argument ``port`` is not specified, the master port will be read from the
+    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set, a
+    default port ``29500`` will be used.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+    # specify master port
+    if port is not None:
+        os.environ['MASTER_PORT'] = str(port)
+    elif 'MASTER_PORT' in os.environ:
+        pass  # use MASTER_PORT in the environment variable
+    else:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if dist.is_available():
+        initialized = dist.is_initialized()
+    else:
+        initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
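
A short usage sketch of the two helpers above, as they are typically used in this codebase: rank-aware logging and restricting side effects to rank 0. The function and path names are illustrative.

from basicsr.utils.dist_util import get_dist_info, master_only

rank, world_size = get_dist_info()   # (0, 1) when torch.distributed is not initialized
print(f'rank {rank} of {world_size}')

@master_only
def save_checkpoint(path):
    # runs only on rank 0; other ranks return None
    print(f'saving to {path}')

save_checkpoint('checkpoint.pth')    # placeholder path
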
diff --git a/CodeFormer/basicsr/utils/download_util.py b/CodeFormer/basicsr/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a267915743ee3f3232bc8fe992466b52468979a
--- /dev/null
+++ b/CodeFormer/basicsr/utils/download_util.py
@@ -0,0 +1,95 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+    """Download files from google drive.
+    Ref:
+    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501
+    Args:
+        file_id (str): File id.
+        save_path (str): Save path.
+    """
+
+    session = requests.Session()
+    URL = 'https://docs.google.com/uc?export=download'
+    params = {'id': file_id}
+
+    response = session.get(URL, params=params, stream=True)
+    token = get_confirm_token(response)
+    if token:
+        params['confirm'] = token
+        response = session.get(URL, params=params, stream=True)
+
+    # get file size
+    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+    print(response_file_size)
+    if 'Content-Range' in response_file_size.headers:
+        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+    else:
+        file_size = None
+
+    save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+    return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+    if file_size is not None:
+        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+        readable_file_size = sizeof_fmt(file_size)
+    else:
+        pbar = None
+
+    with open(destination, 'wb') as f:
+        downloaded_size = 0
+        for chunk in response.iter_content(chunk_size):
+            downloaded_size += chunk_size
+            if pbar is not None:
+                pbar.update(1)
+                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+        if pbar is not None:
+            pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Load file form http url, will download models if necessary.
+    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    Args:
+        url (str): URL to be downloaded.
+        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+            Default: None.
+        progress (bool): Whether to show the download progress. Default: True.
+        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+    Returns:
+        str: The path to the downloaded file.
+    """
+    if model_dir is None:  # use the pytorch hub_dir
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file
\ No newline at end of file
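
A hedged usage sketch for `load_file_from_url`; the URL and directory below are placeholders, not real release assets.

import os
from basicsr.utils.download_util import load_file_from_url

ckpt_path = load_file_from_url(
    url='https://example.com/codeformer.pth',   # placeholder URL
    model_dir=os.path.abspath('weights'),       # full path, as the docstring recommends
    progress=True,
    file_name=None,                             # keep the file name from the URL
)
print(ckpt_path)  # .../weights/codeformer.pth
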
diff --git a/CodeFormer/basicsr/utils/file_client.py b/CodeFormer/basicsr/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f38d9796da3899048924f2f803d1088927966b0
--- /dev/null
+++ b/CodeFormer/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+    """Abstract class of storage backends.
+
+    All backends need to implement two APIs: ``get()`` and ``get_text()``.
+    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+    as text.
+    """
+
+    @abstractmethod
+    def get(self, filepath):
+        pass
+
+    @abstractmethod
+    def get_text(self, filepath):
+        pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+    """Memcached storage backend.
+
+    Attributes:
+        server_list_cfg (str): Config file for memcached server list.
+        client_cfg (str): Config file for memcached client.
+        sys_path (str | None): Additional path to be appended to `sys.path`.
+            Default: None.
+    """
+
+    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+        if sys_path is not None:
+            import sys
+            sys.path.append(sys_path)
+        try:
+            import mc
+        except ImportError:
+            raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+        self.server_list_cfg = server_list_cfg
+        self.client_cfg = client_cfg
+        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+        self._mc_buffer = mc.pyvector()
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        import mc
+        self._client.Get(filepath, self._mc_buffer)
+        value_buf = mc.ConvertBuffer(self._mc_buffer)
+        return value_buf
+
+    def get_text(self, filepath):
+        raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+    """Raw hard disks storage backend."""
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'rb') as f:
+            value_buf = f.read()
+        return value_buf
+
+    def get_text(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'r') as f:
+            value_buf = f.read()
+        return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+    """Lmdb storage backend.
+
+    Args:
+        db_paths (str | list[str]): Lmdb database paths.
+        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+        readonly (bool, optional): Lmdb environment parameter. If True,
+            disallow any write operations. Default: True.
+        lock (bool, optional): Lmdb environment parameter. If False, when
+            concurrent access occurs, do not lock the database. Default: False.
+        readahead (bool, optional): Lmdb environment parameter. If False,
+            disable the OS filesystem readahead mechanism, which may improve
+            random read performance when a database is larger than RAM.
+            Default: False.
+
+    Attributes:
+        db_paths (list): Lmdb database paths.
+        _client (dict): A dict of lmdb envs, keyed by client key.
+    """
+
+    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+        try:
+            import lmdb
+        except ImportError:
+            raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+        if isinstance(client_keys, str):
+            client_keys = [client_keys]
+
+        if isinstance(db_paths, list):
+            self.db_paths = [str(v) for v in db_paths]
+        elif isinstance(db_paths, str):
+            self.db_paths = [str(db_paths)]
+        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+        self._client = {}
+        for client, path in zip(client_keys, self.db_paths):
+            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+    def get(self, filepath, client_key):
+        """Get values according to the filepath from one lmdb named client_key.
+
+        Args:
+            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+        """
+        filepath = str(filepath)
+        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+        client = self._client[client_key]
+        with client.begin(write=False) as txn:
+            value_buf = txn.get(filepath.encode('ascii'))
+        return value_buf
+
+    def get_text(self, filepath):
+        raise NotImplementedError
+
+
+class FileClient(object):
+    """A general file client to access files in different backend.
+
+    The client loads a file or text in a specified backend from its path
+    and return it as a binary file. it can also register other backend
+    accessor with a given name and backend class.
+
+    Attributes:
+        backend (str): The storage backend type. Options are "disk",
+            "memcached" and "lmdb".
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+    }
+
+    def __init__(self, backend='disk', **kwargs):
+        if backend not in self._backends:
+            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+                             f' are {list(self._backends.keys())}')
+        self.backend = backend
+        self.client = self._backends[backend](**kwargs)
+
+    def get(self, filepath, client_key='default'):
+        # client_key is used only for lmdb, where different fileclients have
+        # different lmdb environments.
+        if self.backend == 'lmdb':
+            return self.client.get(filepath, client_key)
+        else:
+            return self.client.get(filepath)
+
+    def get_text(self, filepath):
+        return self.client.get_text(filepath)
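+
+
+# Usage sketch (illustrative; the paths below are hypothetical):
+#
+#   client = FileClient(backend='disk')
+#   img_bytes = client.get('datasets/example/0001.png')  # raw bytes
+#
+#   # For lmdb, the database path(s) are given at construction time and the
+#   # lmdb key (not a file path) is passed to get():
+#   lmdb_client = FileClient(backend='lmdb', db_paths='datasets/example.lmdb')
+#   img_bytes = lmdb_client.get('0001', client_key='default')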
diff --git a/CodeFormer/basicsr/utils/img_util.py b/CodeFormer/basicsr/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d409a132ff216e6943a276fb5d8cd5f410824883
--- /dev/null
+++ b/CodeFormer/basicsr/utils/img_util.py
@@ -0,0 +1,170 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Numpy array to tensor.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input images.
+        bgr2rgb (bool): Whether to change bgr to rgb.
+        float32 (bool): Whether to change to float32.
+
+    Returns:
+        list[tensor] | tensor: Tensor images. If returned results only have
+            one element, just return tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+    """Convert torch Tensors into image numpy arrays.
+
+    After clamping to [min, max], values will be normalized to [0, 1].
+
+    Args:
+        tensor (Tensor or list[Tensor]): Accept shapes:
+            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+            2) 3D Tensor of shape (3/1 x H x W);
+            3) 2D Tensor of shape (H x W).
+            Tensor channel should be in RGB order.
+        rgb2bgr (bool): Whether to change rgb to bgr.
+        out_type (numpy type): output types. If ``np.uint8``, transform outputs
+            to uint8 type with range [0, 255]; otherwise, float type with
+            range [0, 1]. Default: ``np.uint8``.
+        min_max (tuple[int]): min and max values for clamp.
+
+    Returns:
+        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
+        ndarray of shape (H x W). The channel order is BGR.
+    """
+    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+    if torch.is_tensor(tensor):
+        tensor = [tensor]
+    result = []
+    for _tensor in tensor:
+        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+        n_dim = _tensor.dim()
+        if n_dim == 4:
+            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if rgb2bgr:
+                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 3:
+            img_np = _tensor.numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if img_np.shape[2] == 1:  # gray image
+                img_np = np.squeeze(img_np, axis=2)
+            else:
+                if rgb2bgr:
+                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 2:
+            img_np = _tensor.numpy()
+        else:
+            raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+        if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+            img_np = (img_np * 255.0).round()
+        img_np = img_np.astype(out_type)
+        result.append(img_np)
+    if len(result) == 1:
+        result = result[0]
+    return result
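+
+
+# Round-trip sketch (illustrative; the random array stands in for a real BGR frame):
+#
+#   img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC, BGR, [0, 1]
+#   t = img2tensor(img, bgr2rgb=True, float32=True)      # CHW, RGB, float32
+#   out = tensor2img(t, rgb2bgr=True, min_max=(0, 1))    # back to HWC, BGR, uint8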
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+    """This implementation is slightly faster than tensor2img.
+    It now only supports torch tensor with shape (1, c, h, w).
+
+    Args:
+        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+        min_max (tuple[int]): min and max values for clamp.
+    """
+    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+    output = output.type(torch.uint8).cpu().numpy()
+    if rgb2bgr:
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+    return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+    """Read an image from bytes.
+
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, it will also
+            normalize the values to [0, 1]. Default: False.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+    img_np = np.frombuffer(content, np.uint8)
+    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+    img = cv2.imdecode(img_np, imread_flags[flag])
+    if float32:
+        img = img.astype(np.float32) / 255.
+    return img
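+
+
+# Decoding sketch (illustrative; the path is hypothetical). This pairs naturally with
+# FileClient.get(), which also returns raw bytes:
+#
+#   with open('inputs/example.jpg', 'rb') as f:
+#       content = f.read()
+#   img = imfrombytes(content, flag='color', float32=True)  # HWC float32 in [0, 1]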
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+    """Crop borders of images.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+    Returns:
+        list[ndarray]: Cropped images.
+    """
+    if crop_border == 0:
+        return imgs
+    else:
+        if isinstance(imgs, list):
+            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+        else:
+            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
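+
+
+# Example (illustrative): cropping 4 pixels from every border of a 64x64 image
+# leaves a 56x56 result.
+#
+#   cropped = crop_border(np.zeros((64, 64, 3), dtype=np.uint8), crop_border=4)
+#   # cropped.shape == (56, 56, 3)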
diff --git a/CodeFormer/basicsr/utils/lmdb_util.py b/CodeFormer/basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723
--- /dev/null
+++ b/CodeFormer/basicsr/utils/lmdb_util.py
@@ -0,0 +1,196 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+                        lmdb_path,
+                        img_path_list,
+                        keys,
+                        batch=5000,
+                        compress_level=1,
+                        multiprocessing_read=False,
+                        n_thread=40,
+                        map_size=None):
+    """Make lmdb from images.
+
+    Contents of lmdb. The file structure is:
+    example.lmdb
+    ├── data.mdb
+    ├── lock.mdb
+    ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a specified txt file to record the meta information
+    of our datasets. It will be automatically created when preparing
+    datasets by our provided dataset tools.
+    Each line in the txt file records 1) image name (with extension),
+    2) image shape, and 3) compression level, separated by a white space.
+
+    For example, the meta information could be:
+    `000_00000000.png (720,1280,3) 1`, which means:
+    1) image name (with extension): 000_00000000.png;
+    2) image shape: (720,1280,3);
+    3) compression level: 1
+
+    We use the image name without extension as the lmdb key.
+
+    If `multiprocessing_read` is True, it will read all the images to memory
+    using multiprocessing. Thus, your server needs to have enough memory.
+
+    Args:
+        data_path (str): Data path for reading images.
+        lmdb_path (str): Lmdb save path.
+        img_path_list (list[str]): Image path list.
+        keys (list[str]): Used for lmdb keys.
+        batch (int): After processing batch images, lmdb commits.
+            Default: 5000.
+        compress_level (int): Compress level when encoding images. Default: 1.
+        multiprocessing_read (bool): Whether to use multiprocessing to read all
+            the images to memory. Default: False.
+        n_thread (int): Number of worker processes for multiprocessing read.
+        map_size (int | None): Map size for lmdb env. If None, use the
+            estimated size from images. Default: None
+    """
+
+    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+                                             f'but got {len(img_path_list)} and {len(keys)}')
+    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+    print(f'Total images: {len(img_path_list)}')
+    if not lmdb_path.endswith('.lmdb'):
+        raise ValueError("lmdb_path must end with '.lmdb'.")
+    if osp.exists(lmdb_path):
+        print(f'Folder {lmdb_path} already exists. Exit.')
+        sys.exit(1)
+
+    if multiprocessing_read:
+        # read all the images to memory (multiprocessing)
+        dataset = {}  # use dict to keep the order for multiprocessing
+        shapes = {}
+        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+        pbar = tqdm(total=len(img_path_list), unit='image')
+
+        def callback(arg):
+            """get the image data and update pbar."""
+            key, dataset[key], shapes[key] = arg
+            pbar.update(1)
+            pbar.set_description(f'Read {key}')
+
+        pool = Pool(n_thread)
+        for path, key in zip(img_path_list, keys):
+            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+        pool.close()
+        pool.join()
+        pbar.close()
+        print(f'Finish reading {len(img_path_list)} images.')
+
+    # create lmdb environment
+    if map_size is None:
+        # obtain data size for one image
+        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+        data_size_per_img = img_byte.nbytes
+        print('Data size per image is: ', data_size_per_img)
+        data_size = data_size_per_img * len(img_path_list)
+        map_size = data_size * 10
+
+    env = lmdb.open(lmdb_path, map_size=map_size)
+
+    # write data to lmdb
+    pbar = tqdm(total=len(img_path_list), unit='chunk')
+    txn = env.begin(write=True)
+    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+        pbar.update(1)
+        pbar.set_description(f'Write {key}')
+        key_byte = key.encode('ascii')
+        if multiprocessing_read:
+            img_byte = dataset[key]
+            h, w, c = shapes[key]
+        else:
+            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+            h, w, c = img_shape
+
+        txn.put(key_byte, img_byte)
+        # write meta information
+        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+        if idx % batch == 0:
+            txn.commit()
+            txn = env.begin(write=True)
+    pbar.close()
+    txn.commit()
+    env.close()
+    txt_file.close()
+    print('\nFinish writing lmdb.')
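+
+
+# Usage sketch (illustrative; paths and file names below are hypothetical):
+#
+#   img_names = ['00000.png', '00001.png']                # relative to data_path
+#   keys = [osp.splitext(name)[0] for name in img_names]  # lmdb keys without extension
+#   make_lmdb_from_imgs('datasets/ffhq_512', 'datasets/ffhq_512.lmdb', img_names, keys)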
+
+
+def read_img_worker(path, key, compress_level):
+    """Read image worker.
+
+    Args:
+        path (str): Image path.
+        key (str): Image key.
+        compress_level (int): Compress level when encoding images.
+
+    Returns:
+        str: Image key.
+        byte: Image byte.
+        tuple[int]: Image shape.
+    """
+
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+    if img.ndim == 2:
+        h, w = img.shape
+        c = 1
+    else:
+        h, w, c = img.shape
+    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+    return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+    """LMDB Maker.
+
+    Args:
+        lmdb_path (str): Lmdb save path.
+        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+        batch (int): After processing batch images, lmdb commits.
+            Default: 5000.
+        compress_level (int): Compress level when encoding images. Default: 1.
+    """
+
+    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+        if not lmdb_path.endswith('.lmdb'):
+            raise ValueError("lmdb_path must end with '.lmdb'.")
+        if osp.exists(lmdb_path):
+            print(f'Folder {lmdb_path} already exists. Exit.')
+            sys.exit(1)
+
+        self.lmdb_path = lmdb_path
+        self.batch = batch
+        self.compress_level = compress_level
+        self.env = lmdb.open(lmdb_path, map_size=map_size)
+        self.txn = self.env.begin(write=True)
+        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+        self.counter = 0
+
+    def put(self, img_byte, key, img_shape):
+        self.counter += 1
+        key_byte = key.encode('ascii')
+        self.txn.put(key_byte, img_byte)
+        # write meta information
+        h, w, c = img_shape
+        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+        if self.counter % self.batch == 0:
+            self.txn.commit()
+            self.txn = self.env.begin(write=True)
+
+    def close(self):
+        self.txn.commit()
+        self.env.close()
+        self.txt_file.close()
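+
+
+# Incremental alternative to make_lmdb_from_imgs (illustrative; paths are hypothetical):
+#
+#   maker = LmdbMaker('datasets/ffhq_512.lmdb')
+#   key, img_byte, img_shape = read_img_worker('datasets/ffhq_512/00000.png', '00000', 1)
+#   maker.put(img_byte, key, img_shape)
+#   maker.close()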
diff --git a/CodeFormer/basicsr/utils/logger.py b/CodeFormer/basicsr/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9714bf59c30fc82de24c1ee58d9118d0864b3572
--- /dev/null
+++ b/CodeFormer/basicsr/utils/logger.py
@@ -0,0 +1,169 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class MessageLogger():
+    """Message logger for printing.
+    Args:
+        opt (dict): Config. It contains the following keys:
+            name (str): Exp name.
+            logger (dict): Contains 'print_freq' (str) for logger interval.
+            train (dict): Contains 'total_iter' (int) for total iters.
+            use_tb_logger (bool): Use tensorboard logger.
+        start_iter (int): Start iter. Default: 1.
+        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+    """
+
+    def __init__(self, opt, start_iter=1, tb_logger=None):
+        self.exp_name = opt['name']
+        self.interval = opt['logger']['print_freq']
+        self.start_iter = start_iter
+        self.max_iters = opt['train']['total_iter']
+        self.use_tb_logger = opt['logger']['use_tb_logger']
+        self.tb_logger = tb_logger
+        self.start_time = time.time()
+        self.logger = get_root_logger()
+
+    @master_only
+    def __call__(self, log_vars):
+        """Format logging message.
+        Args:
+            log_vars (dict): It contains the following keys:
+                epoch (int): Epoch number.
+                iter (int): Current iter.
+                lrs (list): List for learning rates.
+                time (float): Iter time.
+                data_time (float): Data time for each iter.
+        """
+        # epoch, iter, learning rates
+        epoch = log_vars.pop('epoch')
+        current_iter = log_vars.pop('iter')
+        lrs = log_vars.pop('lrs')
+
+        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
+        for v in lrs:
+            message += f'{v:.3e},'
+        message += ')] '
+
+        # time and estimated time
+        if 'time' in log_vars.keys():
+            iter_time = log_vars.pop('time')
+            data_time = log_vars.pop('data_time')
+
+            total_time = time.time() - self.start_time
+            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+            message += f'[eta: {eta_str}, '
+            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+        # other items, especially losses
+        for k, v in log_vars.items():
+            message += f'{k}: {v:.4e} '
+            # tensorboard logger
+            if self.use_tb_logger:
+                if k.startswith('l_'):
+                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+                else:
+                    self.tb_logger.add_scalar(k, v, current_iter)
+        self.logger.info(message)
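+
+
+# Construction sketch (illustrative; the option values are hypothetical and would
+# normally come from the parsed YAML config):
+#
+#   opt = {'name': 'example_experiment',
+#          'logger': {'print_freq': 100, 'use_tb_logger': False},
+#          'train': {'total_iter': 800000}}
+#   msg_logger = MessageLogger(opt, start_iter=1)
+#   msg_logger({'epoch': 0, 'iter': 100, 'lrs': [1e-4], 'l_pix': 0.02})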
+
+
+@master_only
+def init_tb_logger(log_dir):
+    from torch.utils.tensorboard import SummaryWriter
+    tb_logger = SummaryWriter(log_dir=log_dir)
+    return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+    """We now only use wandb to sync tensorboard log."""
+    import wandb
+    logger = logging.getLogger('basicsr')
+
+    project = opt['logger']['wandb']['project']
+    resume_id = opt['logger']['wandb'].get('resume_id')
+    if resume_id:
+        wandb_id = resume_id
+        resume = 'allow'
+        logger.warning(f'Resume wandb logger with id={wandb_id}.')
+    else:
+        wandb_id = wandb.util.generate_id()
+        resume = 'never'
+
+    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+    """Get the root logger.
+    The logger will be initialized if it has not been initialized. By default a
+    StreamHandler will be added. If `log_file` is specified, a FileHandler will
+    also be added.
+    Args:
+        logger_name (str): root logger name. Default: 'basicsr'.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the root logger.
+        log_level (int): The root logger level. Note that only the process of
+            rank 0 is affected, while other processes will set the level to
+            "Error" and be silent most of the time.
+    Returns:
+        logging.Logger: The root logger.
+    """
+    logger = logging.getLogger(logger_name)
+    # if the logger has been initialized, just return it
+    if logger_name in initialized_logger:
+        return logger
+
+    format_str = '%(asctime)s %(levelname)s: %(message)s'
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter(format_str))
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    elif log_file is not None:
+        logger.setLevel(log_level)
+        # add file handler
+        # file_handler = logging.FileHandler(log_file, 'w')
+        file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+        file_handler.setFormatter(logging.Formatter(format_str))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+    initialized_logger[logger_name] = True
+    return logger
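+
+
+# Usage sketch (illustrative; the log file path is hypothetical):
+#
+#   logger = get_root_logger(log_level=logging.INFO, log_file='experiments/train.log')
+#   logger.info('goes to the console and, on rank 0, to the file as well')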
+
+
+def get_env_info():
+    """Get environment information.
+    Currently, it only logs the software version information.
+    """
+    import torch
+    import torchvision
+
+    from basicsr.version import __version__
+    msg = r"""
+                ____                _       _____  ____
+               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
+             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
+            /_____/ \__,_//____//_/ \___//____//_/ |_|
+     ______                   __   __                 __      __
+    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
+   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
+  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
+  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
+    """
+    msg += ('\nVersion Information: '
+            f'\n\tBasicSR: {__version__}'
+            f'\n\tPyTorch: {torch.__version__}'
+            f'\n\tTorchVision: {torchvision.__version__}')
+    return msg
\ No newline at end of file
diff --git a/CodeFormer/basicsr/utils/matlab_functions.py b/CodeFormer/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ce1004a2c9f8521505c4b5889d3c24a909c70d
--- /dev/null
+++ b/CodeFormer/basicsr/utils/matlab_functions.py
@@ -0,0 +1,347 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+    """cubic function used for calculate_weights_indices."""
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+                                                                                     (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    """Calculate weights and indices, used for imresize function.
+
+    Args:
+        in_length (int): Input length.
+        out_length (int): Output length.
+        scale (float): Scale factor.
+        kernel (str): Kernel type. Currently only 'cubic' is supported.
+        kernel_width (int): Kernel width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+    """
+
+    if (scale < 1) and antialiasing:
+        # Use a modified kernel (larger kernel width) to simultaneously
+        # interpolate and antialias
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5 + scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    p = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+        out_length, p)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+    # apply cubic kernel
+    if (scale < 1) and antialiasing:
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, p)
+
+    # If a column in weights is all zero, get rid of it. only consider the
+    # first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, p - 2)
+        weights = weights.narrow(1, 1, p - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, p - 2)
+        weights = weights.narrow(1, 0, p - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+    """imresize function same as MATLAB.
+
+    It now only supports bicubic.
+    The same scale applies for both height and width.
+
+    Args:
+        img (Tensor | Numpy array):
+            Tensor: Input image with shape (c, h, w), [0, 1] range.
+            Numpy: Input image with shape (h, w, c), [0, 1] range.
+        scale (float): Scale factor. The same scale applies for both height
+            and width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+            Default: True.
+
+    Returns:
+        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+    """
+    if type(img).__module__ == np.__name__:  # numpy type
+        numpy_type = True
+        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+    else:
+        numpy_type = False
+
+    in_c, in_h, in_w = img.size()
+    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # get weights and indices
+    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+                                                                             antialiasing)
+    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+                                                                             antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+    sym_patch = img[:, :sym_len_hs, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+    sym_patch = img[:, -sym_len_he:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(in_c, out_h, in_w)
+    kernel_width = weights_h.size(1)
+    for i in range(out_h):
+        idx = int(indices_h[i][0])
+        for j in range(in_c):
+            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+    sym_patch = out_1[:, :, :sym_len_ws]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, :, -sym_len_we:]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(in_c, out_h, out_w)
+    kernel_width = weights_w.size(1)
+    for i in range(out_w):
+        idx = int(indices_w[i][0])
+        for j in range(in_c):
+            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+    if numpy_type:
+        out_2 = out_2.numpy().transpose(1, 2, 0)
+    return out_2
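+
+
+# Example (illustrative): MATLAB-compatible bicubic downscaling of a float image.
+#
+#   img = np.random.rand(128, 128, 3).astype(np.float32)  # HWC, [0, 1]
+#   small = imresize(img, 0.5)                             # ndarray of shape (64, 64, 3)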
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert a RGB image to YCbCr image.
+
+    This function produces the same results as Matlab's `rgb2ycbcr` function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to RGB image.
+
+    This function produces the same results as Matlab's ycbcr2rgb function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to BGR image.
+
+    The bgr version of ycbcr2rgb.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
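+
+
+# Example (illustrative): extracting the BT.601 luma channel from a BGR uint8 frame.
+#
+#   bgr = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
+#   y = bgr2ycbcr(bgr, y_only=True)  # uint8 Y channel
+#   ycbcr = bgr2ycbcr(bgr)           # full 3-channel YCbCr image
+#   bgr_back = ycbcr2bgr(ycbcr)      # approximate round trip (uint8 rounding)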
+
+
+def _convert_input_type_range(img):
+    """Convert the type and range of the input image.
+
+    It converts the input image to np.float32 type and range of [0, 1].
+    It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with type of np.float32 and range of
+            [0, 1].
+    """
+    img_type = img.dtype
+    img = img.astype(np.float32)
+    if img_type == np.float32:
+        pass
+    elif img_type == np.uint8:
+        img /= 255.
+    else:
+        raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+    return img
+
+
+def _convert_output_type_range(img, dst_type):
+    """Convert the type and range of the image according to dst_type.
+
+    It converts the image to desired type and range. If `dst_type` is np.uint8,
+    images will be converted to np.uint8 type with range [0, 255]. If
+    `dst_type` is np.float32, it converts the image to np.float32 type with
+    range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+    functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+            converts the image to np.uint8 type with range [0, 255]. If
+            dst_type is np.float32, it converts the image to np.float32 type
+            with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+    if dst_type == np.uint8:
+        img = img.round()
+    else:
+        img /= 255.
+    return img.astype(dst_type)
diff --git a/CodeFormer/basicsr/utils/misc.py b/CodeFormer/basicsr/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b444ff3b950e38f43a5451d1330ff1b65951a9e
--- /dev/null
+++ b/CodeFormer/basicsr/utils/misc.py
@@ -0,0 +1,134 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+
+def set_random_seed(seed):
+    """Set random seeds."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+    """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+    Args:
+        path (str): Folder path.
+    """
+    if osp.exists(path):
+        new_name = path + '_archived_' + get_time_str()
+        print(f'Path already exists. Rename it to {new_name}', flush=True)
+        os.rename(path, new_name)
+    os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+    """Make dirs for experiments."""
+    path_opt = opt['path'].copy()
+    if opt['is_train']:
+        mkdir_and_rename(path_opt.pop('experiments_root'))
+    else:
+        mkdir_and_rename(path_opt.pop('results_root'))
+    for key, path in path_opt.items():
+        if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
+            os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+
+    Args:
+        dir_path (str): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+
+    Returns:
+        A generator for all the interested files with relative paths.
+    """
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if suffix is None:
+                    yield return_path
+                elif return_path.endswith(suffix):
+                    yield return_path
+            else:
+                if recursive:
+                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
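+
+
+# Usage sketch (illustrative; the directory is hypothetical):
+#
+#   for rel_path in scandir('datasets/ffhq_512', suffix='.png', recursive=True):
+#       print(rel_path)  # paths relative to 'datasets/ffhq_512'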
+
+
+def check_resume(opt, resume_iter):
+    """Check resume states and pretrain_network paths.
+
+    Args:
+        opt (dict): Options.
+        resume_iter (int): Resume iteration.
+    """
+    logger = get_root_logger()
+    if opt['path']['resume_state']:
+        # get all the networks
+        networks = [key for key in opt.keys() if key.startswith('network_')]
+        flag_pretrain = False
+        for network in networks:
+            if opt['path'].get(f'pretrain_{network}') is not None:
+                flag_pretrain = True
+        if flag_pretrain:
+            logger.warning('pretrain_network path will be ignored during resuming.')
+        # set pretrained model paths
+        for network in networks:
+            name = f'pretrain_{network}'
+            basename = network.replace('network_', '')
+            if opt['path'].get('ignore_resume_networks') is None or (basename
+                                                                     not in opt['path']['ignore_resume_networks']):
+                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+                logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix='B'):
+    """Get human readable file size.
+
+    Args:
+        size (int): File size.
+        suffix (str): Suffix. Default: 'B'.
+
+    Return:
+        str: Formated file siz.
+    """
+    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+        if abs(size) < 1024.0:
+            return f'{size:3.1f} {unit}{suffix}'
+        size /= 1024.0
+    return f'{size:3.1f} Y{suffix}'
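+
+
+# Examples (illustrative): sizeof_fmt(1536) returns '1.5 KB', and
+# sizeof_fmt(3 * 1024 ** 3) returns '3.0 GB'.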
diff --git a/CodeFormer/basicsr/utils/options.py b/CodeFormer/basicsr/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..db490e4aa52e26fde31959fd74c2cef3af2ecf76
--- /dev/null
+++ b/CodeFormer/basicsr/utils/options.py
@@ -0,0 +1,108 @@
+import yaml
+import time
+from collections import OrderedDict
+from os import path as osp
+from basicsr.utils.misc import get_time_str
+
+def ordered_yaml():
+    """Support OrderedDict for yaml.
+
+    Returns:
+        yaml Loader and Dumper.
+    """
+    try:
+        from yaml import CDumper as Dumper
+        from yaml import CLoader as Loader
+    except ImportError:
+        from yaml import Dumper, Loader
+
+    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+    def dict_representer(dumper, data):
+        return dumper.represent_dict(data.items())
+
+    def dict_constructor(loader, node):
+        return OrderedDict(loader.construct_pairs(node))
+
+    Dumper.add_representer(OrderedDict, dict_representer)
+    Loader.add_constructor(_mapping_tag, dict_constructor)
+    return Loader, Dumper
+
+
+def parse(opt_path, root_path, is_train=True):
+    """Parse option file.
+
+    Args:
+        opt_path (str): Option file path.
+        root_path (str): Root path used to build experiment/result directories.
+        is_train (bool): Indicate whether in training or not. Default: True.
+
+    Returns:
+        (dict): Options.
+    """
+    with open(opt_path, mode='r') as f:
+        Loader, _ = ordered_yaml()
+        opt = yaml.load(f, Loader=Loader)
+
+    opt['is_train'] = is_train
+
+    # opt['name'] = f"{get_time_str()}_{opt['name']}"
+    if opt['path'].get('resume_state', None): # Shangchen added
+        resume_state_path = opt['path'].get('resume_state')
+        opt['name'] = resume_state_path.split("/")[-3]
+    else:
+        opt['name'] = f"{get_time_str()}_{opt['name']}"
+
+
+    # datasets
+    for phase, dataset in opt['datasets'].items():
+        # for several datasets, e.g., test_1, test_2
+        phase = phase.split('_')[0]
+        dataset['phase'] = phase
+        if 'scale' in opt:
+            dataset['scale'] = opt['scale']
+        if dataset.get('dataroot_gt') is not None:
+            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+        if dataset.get('dataroot_lq') is not None:
+            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+    # paths
+    for key, val in opt['path'].items():
+        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+            opt['path'][key] = osp.expanduser(val)
+
+    if is_train:
+        experiments_root = osp.join(root_path, 'experiments', opt['name'])
+        opt['path']['experiments_root'] = experiments_root
+        opt['path']['models'] = osp.join(experiments_root, 'models')
+        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+        opt['path']['log'] = experiments_root
+        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+    else:  # test
+        results_root = osp.join(root_path, 'results', opt['name'])
+        opt['path']['results_root'] = results_root
+        opt['path']['log'] = results_root
+        opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+    return opt
+
+
+def dict2str(opt, indent_level=1):
+    """dict to string for printing options.
+
+    Args:
+        opt (dict): Option dict.
+        indent_level (int): Indent level. Default: 1.
+
+    Return:
+        (str): Option string for printing.
+    """
+    msg = '\n'
+    for k, v in opt.items():
+        if isinstance(v, dict):
+            msg += ' ' * (indent_level * 2) + k + ':['
+            msg += dict2str(v, indent_level + 1)
+            msg += ' ' * (indent_level * 2) + ']\n'
+        else:
+            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+    return msg
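+
+
+# Usage sketch (illustrative; the YAML path and root are hypothetical):
+#
+#   opt = parse('options/example_stage.yml', root_path='.', is_train=True)
+#   print(dict2str(opt))  # pretty-print the nested option dict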
diff --git a/CodeFormer/basicsr/utils/realesrgan_utils.py b/CodeFormer/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff94523b7ddd61f0b72280950fd36e1b8133bf4c
--- /dev/null
+++ b/CodeFormer/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,296 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can also be a url (the model will be downloaded automatically).
+        model (nn.Module): The defined network. Default: None.
+        tile (int): Since overly large images can cause out-of-GPU-memory issues, this tile option first crops the
+            input image into tiles, processes each tile, and finally merges them back into one image.
+            0 means tiling is disabled. Default: 0.
+        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+    """
+
+    def __init__(self,
+                 scale,
+                 model_path,
+                 model=None,
+                 tile=0,
+                 tile_pad=10,
+                 pre_pad=10,
+                 half=False,
+                 device=None,
+                 gpu_id=None):
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale = None
+        self.half = half
+
+        # initialize model
+        if gpu_id:
+            self.device = torch.device(
+                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+        else:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+        # if the model_path starts with https, it will first download models to the folder: weights/realesrgan
+        if model_path.startswith('https://'):
+            model_path = load_file_from_url(
+                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+        # prefer to use params_ema
+        if 'params_ema' in loadnet:
+            keyname = 'params_ema'
+        else:
+            keyname = 'params'
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+        if self.half:
+            self.model = self.model.half()
+
+    def pre_process(self, img):
+        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+        """
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img.unsqueeze(0).to(self.device)
+        if self.half:
+            self.img = self.img.half()
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+        # mod pad for divisible borders
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if (h % self.mod_scale != 0):
+                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+            if (w % self.mod_scale != 0):
+                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+    def process(self):
+        # model inference
+        self.output = self.model(self.img)
+
+    def tile_process(self):
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one images.
+
+        Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
+
+        # loop over all tiles
+        for y in range(tiles_y):
+            for x in range(tiles_x):
+                # extract tile from input image
+                ofs_x = x * self.tile_size
+                ofs_y = y * self.tile_size
+                # input tile area on total image
+                input_start_x = ofs_x
+                input_end_x = min(ofs_x + self.tile_size, width)
+                input_start_y = ofs_y
+                input_end_y = min(ofs_y + self.tile_size, height)
+
+                # input tile area on total image with padding
+                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                # input tile dimensions
+                input_tile_width = input_end_x - input_start_x
+                input_tile_height = input_end_y - input_start_y
+                tile_idx = y * tiles_x + x + 1
+                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                # upscale tile
+                try:
+                    with torch.no_grad():
+                        output_tile = self.model(input_tile)
+                except RuntimeError as error:
+                    print('Error', error)
+                # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+                # output tile area on total image
+                output_start_x = input_start_x * self.scale
+                output_end_x = input_end_x * self.scale
+                output_start_y = input_start_y * self.scale
+                output_end_y = input_end_y * self.scale
+
+                # output tile area without padding
+                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                # put tile into output image
+                self.output[:, :, output_start_y:output_end_y,
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
+
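+    # Illustrative tile arithmetic (an assumed concrete setting, not part of the
+    # original code): with width = 1000, tile_size = 400 and tile_pad = 10,
+    # tiles_x = math.ceil(1000 / 400) = 3. For the last tile (x = 2) the unpadded
+    # crop covers columns 800..1000 while the padded crop covers 790..1000; the
+    # extra padded columns are trimmed again when writing into self.output.
+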
+    def post_process(self):
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+        # remove prepad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+        return self.output
+
+    @torch.no_grad()
+    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+        h_input, w_input = img.shape[0:2]
+        # img: numpy
+        img = img.astype(np.float32)
+        if np.max(img) > 256:  # 16-bit image
+            max_range = 65535
+            print('\tInput is a 16-bit image')
+        else:
+            max_range = 255
+        img = img / max_range
+        if len(img.shape) == 2:  # gray image
+            img_mode = 'L'
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            img_mode = 'RGBA'
+            alpha = img[:, :, 3]
+            img = img[:, :, 0:3]
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if alpha_upsampler == 'realesrgan':
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = 'RGB'
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+        with torch.no_grad():
+            self.pre_process(img)
+            if self.tile_size > 0:
+                self.tile_process()
+            else:
+                self.process()
+            output_img_t = self.post_process()
+            output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+            output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+            if img_mode == 'L':
+                output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+        del output_img_t
+        torch.cuda.empty_cache()
+
+        # ------------------- process the alpha channel if necessary ------------------- #
+        if img_mode == 'RGBA':
+            if alpha_upsampler == 'realesrgan':
+                self.pre_process(alpha)
+                if self.tile_size > 0:
+                    self.tile_process()
+                else:
+                    self.process()
+                output_alpha = self.post_process()
+                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+            else:  # use the cv2 resize for alpha channel
+                h, w = alpha.shape[0:2]
+                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+            # merge the alpha channel
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+            output_img[:, :, 3] = output_alpha
+
+        # ------------------------------ return ------------------------------ #
+        if max_range == 65535:  # 16-bit image
+            output = (output_img * 65535.0).round().astype(np.uint16)
+        else:
+            output = (output_img * 255.0).round().astype(np.uint8)
+
+        if outscale is not None and outscale != float(self.scale):
+            output = cv2.resize(
+                output, (
+                    int(w_input * outscale),
+                    int(h_input * outscale),
+                ), interpolation=cv2.INTER_LANCZOS4)
+
+        return output, img_mode
+
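+# Usage sketch (illustrative, not part of the original code): assuming `upsampler`
+# is an instance of the upsampler class above, constructed with a loaded model,
+# a whole image can be restored/upscaled (tiled if tile_size > 0) roughly like:
+#
+#     img = cv2.imread('inputs/example.png', cv2.IMREAD_UNCHANGED)  # hypothetical path
+#     output, img_mode = upsampler.enhance(img, outscale=2)
+#     cv2.imwrite('results/example_x2.png', output)
+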
+
+class PrefetchReader(threading.Thread):
+    """Prefetch images.
+
+    Args:
+        img_list (list[str]): A list of image paths to be read.
+        num_prefetch_queue (int): Number of prefetch queue.
+    """
+
+    def __init__(self, img_list, num_prefetch_queue):
+        super().__init__()
+        self.que = queue.Queue(num_prefetch_queue)
+        self.img_list = img_list
+
+    def run(self):
+        for img_path in self.img_list:
+            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+            self.que.put(img)
+
+        self.que.put(None)
+
+    def __next__(self):
+        next_item = self.que.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class IOConsumer(threading.Thread):
+
+    def __init__(self, opt, que, qid):
+        super().__init__()
+        self._queue = que
+        self.qid = qid
+        self.opt = opt
+
+    def run(self):
+        while True:
+            msg = self._queue.get()
+            if isinstance(msg, str) and msg == 'quit':
+                break
+
+            output = msg['output']
+            save_path = msg['save_path']
+            cv2.imwrite(save_path, output)
+        print(f'IO worker {self.qid} is done.')
\ No newline at end of file
diff --git a/CodeFormer/basicsr/utils/registry.py b/CodeFormer/basicsr/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827
--- /dev/null
+++ b/CodeFormer/basicsr/utils/registry.py
@@ -0,0 +1,82 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501
+
+
+class Registry():
+    """
+    The registry that provides name -> object mapping, to support third-party
+    users' custom modules.
+
+    To create a registry (e.g. a backbone registry):
+
+    .. code-block:: python
+
+        BACKBONE_REGISTRY = Registry('BACKBONE')
+
+    To register an object:
+
+    .. code-block:: python
+
+        @BACKBONE_REGISTRY.register()
+        class MyBackbone():
+            ...
+
+    Or:
+
+    .. code-block:: python
+
+        BACKBONE_REGISTRY.register(MyBackbone)
+    """
+
+    def __init__(self, name):
+        """
+        Args:
+            name (str): the name of this registry
+        """
+        self._name = name
+        self._obj_map = {}
+
+    def _do_register(self, name, obj):
+        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+                                             f"in '{self._name}' registry!")
+        self._obj_map[name] = obj
+
+    def register(self, obj=None):
+        """
+        Register the given object under the name `obj.__name__`.
+        Can be used either as a decorator or as a plain function call.
+        See docstring of this class for usage.
+        """
+        if obj is None:
+            # used as a decorator
+            def deco(func_or_class):
+                name = func_or_class.__name__
+                self._do_register(name, func_or_class)
+                return func_or_class
+
+            return deco
+
+        # used as a function call
+        name = obj.__name__
+        self._do_register(name, obj)
+
+    def get(self, name):
+        ret = self._obj_map.get(name)
+        if ret is None:
+            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+        return ret
+
+    def __contains__(self, name):
+        return name in self._obj_map
+
+    def __iter__(self):
+        return iter(self._obj_map.items())
+
+    def keys(self):
+        return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
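+
+
+# Usage sketch (illustrative, not part of the original code): registering and
+# retrieving a hypothetical architecture class via the registries above:
+#
+#     @ARCH_REGISTRY.register()
+#     class MyArch:  # hypothetical example class
+#         ...
+#
+#     arch_cls = ARCH_REGISTRY.get('MyArch')
+#     assert 'MyArch' in ARCH_REGISTRY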
diff --git a/CodeFormer/facelib/.DS_Store b/CodeFormer/facelib/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..110c1e9d6bc50d68182815182e2124f558185964
Binary files /dev/null and b/CodeFormer/facelib/.DS_Store differ
diff --git a/CodeFormer/facelib/detection/.DS_Store b/CodeFormer/facelib/detection/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..751619a512b36f3d1908c037b85045d2b3d75d62
Binary files /dev/null and b/CodeFormer/facelib/detection/.DS_Store differ
diff --git a/CodeFormer/facelib/detection/__init__.py b/CodeFormer/facelib/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..296262d4e2e29eaa2afba7bda1f0399d77da24f6
--- /dev/null
+++ b/CodeFormer/facelib/detection/__init__.py
@@ -0,0 +1,100 @@
+import os
+import torch
+from torch import nn
+from copy import deepcopy
+
+from facelib.utils import load_file_from_url
+from facelib.utils import download_pretrained_models
+from facelib.detection.yolov5face.models.common import Conv
+
+from .retinaface.retinaface import RetinaFace
+from .yolov5face.face_detector import YoloDetector
+
+
+def init_detection_model(model_name, half=False, device='cuda'):
+    if 'retinaface' in model_name:
+        model = init_retinaface_model(model_name, half, device)
+    elif 'YOLOv5' in model_name:
+        model = init_yolov5face_model(model_name, device)
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    return model
+
+
+def init_retinaface_model(model_name, half=False, device='cuda'):
+    if model_name == 'retinaface_resnet50':
+        model = RetinaFace(network_name='resnet50', half=half)
+        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
+    elif model_name == 'retinaface_mobile0.25':
+        model = RetinaFace(network_name='mobile0.25', half=half)
+        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+    # remove unnecessary 'module.'
+    for k, v in deepcopy(load_net).items():
+        if k.startswith('module.'):
+            load_net[k[7:]] = v
+            load_net.pop(k)
+    model.load_state_dict(load_net, strict=True)
+    model.eval()
+    model = model.to(device)
+
+    return model
+
+
+def init_yolov5face_model(model_name, device='cuda'):
+    if model_name == 'YOLOv5l':
+        model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+        model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
+    elif model_name == 'YOLOv5n':
+        model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+        model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+    
+    model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+    model.detector.load_state_dict(load_net, strict=True)
+    model.detector.eval()
+    model.detector = model.detector.to(device).float()
+
+    for m in model.detector.modules():
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+            m.inplace = True  # pytorch 1.7.0 compatibility
+        elif isinstance(m, Conv):
+            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
+
+    return model
+
+
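+# Usage sketch (illustrative, not part of the original code): a detector can be
+# built by name; the weights are downloaded to weights/facelib on first use:
+#
+#     detector = init_detection_model('retinaface_resnet50', half=False, device='cuda')
+#     # or: detector = init_detection_model('YOLOv5n', device='cuda')
+
+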
+# Download from Google Drive
+# def init_yolov5face_model(model_name, device='cuda'):
+#     if model_name == 'YOLOv5l':
+#         model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+#         f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
+#     elif model_name == 'YOLOv5n':
+#         model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+#         f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
+#     else:
+#         raise NotImplementedError(f'{model_name} is not implemented.')
+
+#     model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
+#     if not os.path.exists(model_path):
+#         download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
+
+#     load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+#     model.detector.load_state_dict(load_net, strict=True)
+#     model.detector.eval()
+#     model.detector = model.detector.to(device).float()
+
+#     for m in model.detector.modules():
+#         if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+#             m.inplace = True  # pytorch 1.7.0 compatibility
+#         elif isinstance(m, Conv):
+#             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
+
+#     return model
\ No newline at end of file
diff --git a/CodeFormer/facelib/detection/align_trans.py b/CodeFormer/facelib/detection/align_trans.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f1eb365462c2ec5bbac6d1854c786b6fd6be90
--- /dev/null
+++ b/CodeFormer/facelib/detection/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+    def __str__(self):
+        return 'In File {}: {}'.format(__file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+    """
+    Function:
+    ----------
+        get reference 5 key points according to crop settings:
+        0. Set default crop_size:
+            if default_square:
+                crop_size = (112, 112)
+            else:
+                crop_size = (96, 112)
+        1. Pad the crop_size by inner_padding_factor in each side;
+        2. Resize crop_size into (output_size - outer_padding*2),
+            pad into output_size with outer_padding;
+        3. Output reference_5point;
+    Parameters:
+    ----------
+        @output_size: (w, h) or None
+            size of aligned face image
+        @inner_padding_factor: float
+            padding factor applied to the inner (w, h) region
+        @outer_padding: (w_pad, h_pad)
+            padding (in pixels) added on each side of the resized inner region
+        @default_square: True or False
+            if True:
+                default crop_size = (112, 112)
+            else:
+                default crop_size = (96, 112);
+        !!! make sure, if output_size is not None:
+                (output_size - outer_padding)
+                = some_scale * (default crop_size * (1.0 +
+                inner_padding_factor))
+    Returns:
+    ----------
+        @reference_5point: 5x2 np.array
+            each row is a pair of transformed coordinates (x, y)
+    """
+
+    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+    # 0) make the inner region a square
+    if default_square:
+        size_diff = max(tmp_crop_size) - tmp_crop_size
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += size_diff
+
+    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+        return tmp_5pts
+
+    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+        if output_size is None:
+            return tmp_5pts
+        else:
+            raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+    # check output size
+    if not (0 <= inner_padding_factor <= 1.0):
+        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+        output_size += np.array(outer_padding)
+    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+    # 1) pad the inner region according to inner_padding_factor
+    if inner_padding_factor > 0:
+        size_diff = tmp_crop_size * inner_padding_factor * 2
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+    # 2) resize the padded inner region
+    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+        raise FaceWarpException('Must have (output_size - outer_padding) '
+                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')
+
+    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+    tmp_5pts = tmp_5pts * scale_factor
+    #    size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+    #    tmp_5pts = tmp_5pts + size_diff / 2
+    tmp_crop_size = size_bf_outer_pad
+
+    # 3) add outer_padding to make output_size
+    reference_5point = tmp_5pts + np.array(outer_padding)
+    tmp_crop_size = output_size
+
+    return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+    """
+    Function:
+    ----------
+        get affine transform matrix 'tfm' from src_pts to dst_pts
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points matrix, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points matrix, each row is a pair of coordinates (x, y)
+    Returns:
+    ----------
+        @tfm: 2x3 np.array
+            transform matrix from src_pts to dst_pts
+    """
+
+    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+    n_pts = src_pts.shape[0]
+    ones = np.ones((n_pts, 1), src_pts.dtype)
+    src_pts_ = np.hstack([src_pts, ones])
+    dst_pts_ = np.hstack([dst_pts, ones])
+
+    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+    if rank == 3:
+        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+    elif rank == 2:
+        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+    return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
+    """
+    Function:
+    ----------
+        warp and crop the input face image according to the given facial points
+    Parameters:
+    ----------
+        @src_img: np.array
+            input image (H x W x C)
+        @facial_pts: could be
+            1)a list of K coordinates (x,y)
+        or
+            2) Kx2 or 2xK np.array
+            each row or col is a pair of coordinates (x, y)
+        @reference_pts: could be
+            1) a list of K coordinates (x,y)
+        or
+            2) Kx2 or 2xK np.array
+            each row or col is a pair of coordinates (x, y)
+        or
+            3) None
+            if None, use default reference facial points
+        @crop_size: (w, h)
+            output face image size
+        @align_type: transform type, could be one of
+            1) 'similarity': use similarity transform
+            2) 'cv2_affine': use the first 3 points to do affine transform,
+                    by calling cv2.getAffineTransform()
+            3) 'affine': use all points to do affine transform
+    Returns:
+    ----------
+        @face_img: output face image with size (w, h) = @crop_size
+    """
+
+    if reference_pts is None:
+        if crop_size[0] == 96 and crop_size[1] == 112:
+            reference_pts = REFERENCE_FACIAL_POINTS
+        else:
+            default_square = False
+            inner_padding_factor = 0
+            outer_padding = (0, 0)
+            output_size = crop_size
+
+            reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
+                                                        default_square)
+
+    ref_pts = np.float32(reference_pts)
+    ref_pts_shp = ref_pts.shape
+    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+        raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+    if ref_pts_shp[0] == 2:
+        ref_pts = ref_pts.T
+
+    src_pts = np.float32(facial_pts)
+    src_pts_shp = src_pts.shape
+    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+        raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+    if src_pts_shp[0] == 2:
+        src_pts = src_pts.T
+
+    if src_pts.shape != ref_pts.shape:
+        raise FaceWarpException('facial_pts and reference_pts must have the same shape')
+
+    if align_type == 'cv2_affine':
+        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+    elif align_type == 'affine':
+        tfm = get_affine_transform_matrix(src_pts, ref_pts)
+    else:
+        tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+    return face_img
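+
+
+# Usage sketch (illustrative, not part of the original code): given 5 detected
+# facial landmarks, a 112x112 aligned crop can be obtained roughly like:
+#
+#     facial5points = [[lx1, ly1], [lx2, ly2], [lx3, ly3], [lx4, ly4], [lx5, ly5]]  # hypothetical landmarks
+#     face = warp_and_crop_face(img, facial5points,
+#                               reference_pts=get_reference_facial_points(default_square=True),
+#                               crop_size=(112, 112))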
diff --git a/CodeFormer/facelib/detection/matlab_cp2tform.py b/CodeFormer/facelib/detection/matlab_cp2tform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a8b54a91709c71437e15c68d3be9a9b0a20a34
--- /dev/null
+++ b/CodeFormer/facelib/detection/matlab_cp2tform.py
@@ -0,0 +1,317 @@
+import numpy as np
+from numpy.linalg import inv, lstsq
+from numpy.linalg import matrix_rank as rank
+from numpy.linalg import norm
+
+
+class MatlabCp2tormException(Exception):
+
+    def __str__(self):
+        return 'In File {}: {}'.format(__file__, super().__str__())
+
+
+def tformfwd(trans, uv):
+    """
+    Function:
+    ----------
+        apply affine transform 'trans' to uv
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix
+        @uv: Kx2 np.array
+            each row is a pair of coordinates (x, y)
+
+    Returns:
+    ----------
+        @xy: Kx2 np.array
+            each row is a pair of transformed coordinates (x, y)
+    """
+    uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+    xy = np.dot(uv, trans)
+    xy = xy[:, 0:-1]
+    return xy
+
+
+def tforminv(trans, uv):
+    """
+    Function:
+    ----------
+        apply the inverse of affine transform 'trans' to uv
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix
+        @uv: Kx2 np.array
+            each row is a pair of coordinates (x, y)
+
+    Returns:
+    ----------
+        @xy: Kx2 np.array
+            each row is a pair of inverse-transformed coordinates (x, y)
+    """
+    Tinv = inv(trans)
+    xy = tformfwd(Tinv, uv)
+    return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {'K': 2}
+
+    K = options['K']
+    M = xy.shape[0]
+    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+
+    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+    X = np.vstack((tmp1, tmp2))
+
+    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+    U = np.vstack((u, v))
+
+    # We know that X * r = U
+    if rank(X) >= 2 * K:
+        r, _, _, _ = lstsq(X, U, rcond=-1)
+        r = np.squeeze(r)
+    else:
+        raise Exception('cp2tform:twoUniquePointsReq')
+    sc = r[0]
+    ss = r[1]
+    tx = r[2]
+    ty = r[3]
+
+    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+    T = inv(Tinv)
+    T[:, 2] = np.array([0, 0, 1])
+
+    return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {'K': 2}
+
+    #    uv = np.array(uv)
+    #    xy = np.array(xy)
+
+    # Solve for trans1
+    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+    # Solve for trans2
+
+    # manually reflect the xy data across the Y-axis (on a copy, so xy is not mutated)
+    xyR = xy.copy()
+    xyR[:, 0] = -1 * xyR[:, 0]
+
+    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+    # manually reflect the tform to undo the reflection done on xyR
+    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+    trans2 = np.dot(trans2r, TreflectY)
+
+    # Figure out if trans1 or trans2 is better
+    xy1 = tformfwd(trans1, uv)
+    norm1 = norm(xy1 - xy)
+
+    xy2 = tformfwd(trans2, uv)
+    norm2 = norm(xy2 - xy)
+
+    if norm1 <= norm2:
+        return trans1, trans1_inv
+    else:
+        trans2_inv = inv(trans2)
+        return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'trans':
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y, 1] = [u, v, 1] * trans
+
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points, each row is a pair of transformed
+            coordinates (x, y)
+        @reflective: True or False
+            if True:
+                use reflective similarity transform
+            else:
+                use non-reflective similarity transform
+
+    Returns:
+    ----------
+       @trans: 3x3 np.array
+            transform matrix from uv to xy
+        trans_inv: 3x3 np.array
+            inverse of trans, transform matrix from xy to uv
+    """
+
+    if reflective:
+        trans, trans_inv = findSimilarity(src_pts, dst_pts)
+    else:
+        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+    return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+    """
+    Function:
+    ----------
+        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix from uv to xy
+
+    Returns:
+    ----------
+        @cv2_trans: 2x3 np.array
+            transform matrix from src_pts to dst_pts, could be directly used
+            for cv2.warpAffine()
+    """
+    cv2_trans = trans[:, 0:2].T
+
+    return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points, each row is a pair of transformed
+            coordinates (x, y)
+        reflective: True or False
+            if True:
+                use reflective similarity transform
+            else:
+                use non-reflective similarity transform
+
+    Returns:
+    ----------
+        @cv2_trans: 2x3 np.array
+            transform matrix from src_pts to dst_pts, could be directly used
+            for cv2.warpAffine()
+    """
+    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+    cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+    return cv2_trans
+
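+# Usage sketch (illustrative, not part of the original code): the 2x3 matrix
+# returned here can be passed to cv2.warpAffine (cv2 is not imported in this
+# module, so this is only a sketch):
+#
+#     cv2_trans = get_similarity_transform_for_cv2(src_pts, dst_pts)
+#     aligned = cv2.warpAffine(img, cv2_trans, (112, 112))
+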
+
+if __name__ == '__main__':
+    """
+    u = [0, 6, -2]
+    v = [0, 3, 5]
+    x = [-1, 0, 4]
+    y = [-1, -10, 4]
+
+    # In Matlab, run:
+    #
+    #   uv = [u'; v'];
+    #   xy = [x'; y'];
+    #   tform_sim=cp2tform(uv,xy,'similarity');
+    #
+    #   trans = tform_sim.tdata.T
+    #   ans =
+    #       -0.0764   -1.6190         0
+    #        1.6190   -0.0764         0
+    #       -3.2156    0.0290    1.0000
+    #   trans_inv = tform_sim.tdata.Tinv
+    #    ans =
+    #
+    #       -0.0291    0.6163         0
+    #       -0.6163   -0.0291         0
+    #       -0.0756    1.9826    1.0000
+    #    xy_m=tformfwd(tform_sim, u,v)
+    #
+    #    xy_m =
+    #
+    #       -3.2156    0.0290
+    #        1.1833   -9.9143
+    #        5.0323    2.8853
+    #    uv_m=tforminv(tform_sim, x,y)
+    #
+    #    uv_m =
+    #
+    #        0.5698    1.3953
+    #        6.0872    2.2733
+    #       -2.6570    4.3314
+    """
+    u = [0, 6, -2]
+    v = [0, 3, 5]
+    x = [-1, 0, 4]
+    y = [-1, -10, 4]
+
+    uv = np.array((u, v)).T
+    xy = np.array((x, y)).T
+
+    print('\n--->uv:')
+    print(uv)
+    print('\n--->xy:')
+    print(xy)
+
+    trans, trans_inv = get_similarity_transform(uv, xy)
+
+    print('\n--->trans matrix:')
+    print(trans)
+
+    print('\n--->trans_inv matrix:')
+    print(trans_inv)
+
+    print('\n---> apply transform to uv')
+    print('\nxy_m = uv_augmented * trans')
+    uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+    xy_m = np.dot(uv_aug, trans)
+    print(xy_m)
+
+    print('\nxy_m = tformfwd(trans, uv)')
+    xy_m = tformfwd(trans, uv)
+    print(xy_m)
+
+    print('\n---> apply inverse transform to xy')
+    print('\nuv_m = xy_augmented * trans_inv')
+    xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+    uv_m = np.dot(xy_aug, trans_inv)
+    print(uv_m)
+
+    print('\nuv_m = tformfwd(trans_inv, xy)')
+    uv_m = tformfwd(trans_inv, xy)
+    print(uv_m)
+
+    uv_m = tforminv(trans, xy)
+    print('\nuv_m = tforminv(trans, xy)')
+    print(uv_m)
diff --git a/CodeFormer/facelib/detection/retinaface/retinaface.py b/CodeFormer/facelib/detection/retinaface/retinaface.py
new file mode 100644
index 0000000000000000000000000000000000000000..02593556d88a90232bbe55a062875f4af4520621
--- /dev/null
+++ b/CodeFormer/facelib/detection/retinaface/retinaface.py
@@ -0,0 +1,370 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
+from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
+from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
+                                                 py_cpu_nms)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def generate_config(network_name):
+
+    cfg_mnet = {
+        'name': 'mobilenet0.25',
+        'min_sizes': [[16, 32], [64, 128], [256, 512]],
+        'steps': [8, 16, 32],
+        'variance': [0.1, 0.2],
+        'clip': False,
+        'loc_weight': 2.0,
+        'gpu_train': True,
+        'batch_size': 32,
+        'ngpu': 1,
+        'epoch': 250,
+        'decay1': 190,
+        'decay2': 220,
+        'image_size': 640,
+        'return_layers': {
+            'stage1': 1,
+            'stage2': 2,
+            'stage3': 3
+        },
+        'in_channel': 32,
+        'out_channel': 64
+    }
+
+    cfg_re50 = {
+        'name': 'Resnet50',
+        'min_sizes': [[16, 32], [64, 128], [256, 512]],
+        'steps': [8, 16, 32],
+        'variance': [0.1, 0.2],
+        'clip': False,
+        'loc_weight': 2.0,
+        'gpu_train': True,
+        'batch_size': 24,
+        'ngpu': 4,
+        'epoch': 100,
+        'decay1': 70,
+        'decay2': 90,
+        'image_size': 840,
+        'return_layers': {
+            'layer2': 1,
+            'layer3': 2,
+            'layer4': 3
+        },
+        'in_channel': 256,
+        'out_channel': 256
+    }
+
+    if network_name == 'mobile0.25':
+        return cfg_mnet
+    elif network_name == 'resnet50':
+        return cfg_re50
+    else:
+        raise NotImplementedError(f'network_name={network_name}')
+
+
+class RetinaFace(nn.Module):
+
+    def __init__(self, network_name='resnet50', half=False, phase='test'):
+        super(RetinaFace, self).__init__()
+        self.half_inference = half
+        cfg = generate_config(network_name)
+        self.backbone = cfg['name']
+
+        self.model_name = f'retinaface_{network_name}'
+        self.cfg = cfg
+        self.phase = phase
+        self.target_size, self.max_size = 1600, 2150
+        self.resize, self.scale, self.scale1 = 1., None, None
+        self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
+        self.reference = get_reference_facial_points(default_square=True)
+        # Build network.
+        backbone = None
+        if cfg['name'] == 'mobilenet0.25':
+            backbone = MobileNetV1()
+            self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+        elif cfg['name'] == 'Resnet50':
+            import torchvision.models as models
+            backbone = models.resnet50(pretrained=False)
+            self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+
+        in_channels_stage2 = cfg['in_channel']
+        in_channels_list = [
+            in_channels_stage2 * 2,
+            in_channels_stage2 * 4,
+            in_channels_stage2 * 8,
+        ]
+
+        out_channels = cfg['out_channel']
+        self.fpn = FPN(in_channels_list, out_channels)
+        self.ssh1 = SSH(out_channels, out_channels)
+        self.ssh2 = SSH(out_channels, out_channels)
+        self.ssh3 = SSH(out_channels, out_channels)
+
+        self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+        self.to(device)
+        self.eval()
+        if self.half_inference:
+            self.half()
+
+    def forward(self, inputs):
+        out = self.body(inputs)
+
+        if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
+            out = list(out.values())
+        # FPN
+        fpn = self.fpn(out)
+
+        # SSH
+        feature1 = self.ssh1(fpn[0])
+        feature2 = self.ssh2(fpn[1])
+        feature3 = self.ssh3(fpn[2])
+        features = [feature1, feature2, feature3]
+
+        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+        tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+        ldm_regressions = (torch.cat(tmp, dim=1))
+
+        if self.phase == 'train':
+            output = (bbox_regressions, classifications, ldm_regressions)
+        else:
+            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+        return output
+
+    def __detect_faces(self, inputs):
+        # get scale
+        height, width = inputs.shape[2:]
+        self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
+        tmp = [width, height, width, height, width, height, width, height, width, height]
+        self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
+
+        # forward
+        inputs = inputs.to(device)
+        if self.half_inference:
+            inputs = inputs.half()
+        loc, conf, landmarks = self(inputs)
+
+        # get priorbox
+        priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+        priors = priorbox.forward().to(device)
+
+        return loc, conf, landmarks, priors
+
+    # single image detection
+    def transform(self, image, use_origin_size):
+        # convert to opencv format
+        if isinstance(image, Image.Image):
+            image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        image = image.astype(np.float32)
+
+        # testing scale
+        im_size_min = np.min(image.shape[0:2])
+        im_size_max = np.max(image.shape[0:2])
+        resize = float(self.target_size) / float(im_size_min)
+
+        # prevent bigger axis from being more than max_size
+        if np.round(resize * im_size_max) > self.max_size:
+            resize = float(self.max_size) / float(im_size_max)
+        resize = 1 if use_origin_size else resize
+
+        # resize
+        if resize != 1:
+            image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+        # convert to torch.tensor format
+        # image -= (104, 117, 123)
+        image = image.transpose(2, 0, 1)
+        image = torch.from_numpy(image).unsqueeze(0)
+
+        return image, resize
+
+    def detect_faces(
+        self,
+        image,
+        conf_threshold=0.8,
+        nms_threshold=0.4,
+        use_origin_size=True,
+    ):
+        """
+        Params:
+            image: a BGR image (np.ndarray) or a PIL.Image
+        """
+        image, self.resize = self.transform(image, use_origin_size)
+        image = image.to(device)
+        if self.half_inference:
+            image = image.half()
+        image = image - self.mean_tensor
+
+        loc, conf, landmarks, priors = self.__detect_faces(image)
+
+        boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
+        boxes = boxes * self.scale / self.resize
+        boxes = boxes.cpu().numpy()
+
+        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+        landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
+        landmarks = landmarks * self.scale1 / self.resize
+        landmarks = landmarks.cpu().numpy()
+
+        # ignore low scores
+        inds = np.where(scores > conf_threshold)[0]
+        boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+        # sort
+        order = scores.argsort()[::-1]
+        boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+        # do NMS
+        bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+        keep = py_cpu_nms(bounding_boxes, nms_threshold)
+        bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+        # self.t['forward_pass'].toc()
+        # print(self.t['forward_pass'].average_time)
+        # import sys
+        # sys.stdout.flush()
+        return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+    def __align_multi(self, image, boxes, landmarks, limit=None):
+
+        if len(boxes) < 1:
+            return [], []
+
+        if limit:
+            boxes = boxes[:limit]
+            landmarks = landmarks[:limit]
+
+        faces = []
+        for landmark in landmarks:
+            facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
+
+            warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
+            faces.append(warped_face)
+
+        return np.concatenate((boxes, landmarks), axis=1), faces
+
+    def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+        rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+        boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+        return self.__align_multi(img, boxes, landmarks, limit)
+
+    # batched detection
+    def batched_transform(self, frames, use_origin_size):
+        """
+        Arguments:
+            frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
+                type=np.float32, BGR format).
+            use_origin_size: whether to use origin size.
+        """
+        from_PIL = isinstance(frames[0], Image.Image)
+
+        # convert to opencv format
+        if from_PIL:
+            frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
+            frames = np.asarray(frames, dtype=np.float32)
+
+        # testing scale
+        im_size_min = np.min(frames[0].shape[0:2])
+        im_size_max = np.max(frames[0].shape[0:2])
+        resize = float(self.target_size) / float(im_size_min)
+
+        # prevent bigger axis from being more than max_size
+        if np.round(resize * im_size_max) > self.max_size:
+            resize = float(self.max_size) / float(im_size_max)
+        resize = 1 if use_origin_size else resize
+
+        # resize
+        if resize != 1:
+            if not from_PIL:
+                frames = F.interpolate(frames, scale_factor=resize)
+            else:
+                frames = [
+                    cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+                    for frame in frames
+                ]
+
+        # convert to torch.tensor format
+        if not from_PIL:
+            frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+        else:
+            frames = frames.transpose((0, 3, 1, 2))
+            frames = torch.from_numpy(frames)
+
+        return frames, resize
+
+    def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+        """
+        Arguments:
+            frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+                type=np.uint8, BGR format).
+            conf_threshold: confidence threshold.
+            nms_threshold: nms threshold.
+            use_origin_size: whether to use origin size.
+        Returns:
+            final_bounding_boxes: list of np.array ([n_boxes, 5],
+                type=np.float32).
+            final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+        """
+        # self.t['forward_pass'].tic()
+        frames, self.resize = self.batched_transform(frames, use_origin_size)
+        frames = frames.to(device)
+        frames = frames - self.mean_tensor
+
+        b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+        final_bounding_boxes, final_landmarks = [], []
+
+        # decode
+        priors = priors.unsqueeze(0)
+        b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
+        b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
+        b_conf = b_conf[:, :, 1]
+
+        # index for selection
+        b_indice = b_conf > conf_threshold
+
+        # concat
+        b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+        for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
+
+            # ignore low scores
+            pred, landm = pred[inds, :], landm[inds, :]
+            if pred.shape[0] == 0:
+                final_bounding_boxes.append(np.array([], dtype=np.float32))
+                final_landmarks.append(np.array([], dtype=np.float32))
+                continue
+
+            # sort
+            # order = score.argsort(descending=True)
+            # box, landm, score = box[order], landm[order], score[order]
+
+            # to CPU
+            bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+            # NMS
+            keep = py_cpu_nms(bounding_boxes, nms_threshold)
+            bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+            # append
+            final_bounding_boxes.append(bounding_boxes)
+            final_landmarks.append(landmarks)
+        # self.t['forward_pass'].toc(average=True)
+        # self.batch_time += self.t['forward_pass'].diff
+        # self.total_frame += len(frames)
+        # print(self.batch_time / self.total_frame)
+
+        return final_bounding_boxes, final_landmarks
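+
+
+# Usage sketch (illustrative, not part of the original code): each row returned
+# by detect_faces has 15 values: 4 box coordinates, 1 confidence score and 10
+# landmark coordinates (5 points). align_multi additionally returns 112x112
+# aligned face crops:
+#
+#     det_net = RetinaFace(network_name='resnet50', half=False)
+#     bboxes_and_landmarks = det_net.detect_faces(img, conf_threshold=0.8)
+#     bboxes_with_landmarks, aligned_faces = det_net.align_multi(img, conf_threshold=0.8)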
diff --git a/CodeFormer/facelib/detection/retinaface/retinaface_net.py b/CodeFormer/facelib/detection/retinaface/retinaface_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6aa82d3e9055a838f1f9076b12f05fdfc154d0
--- /dev/null
+++ b/CodeFormer/facelib/detection/retinaface/retinaface_net.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_bn_no_relu(inp, oup, stride):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+    )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_dw(inp, oup, stride, leaky=0.1):
+    return nn.Sequential(
+        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+        nn.BatchNorm2d(inp),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True),
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True),
+    )
+
+
+class SSH(nn.Module):
+
+    def __init__(self, in_channel, out_channel):
+        super(SSH, self).__init__()
+        assert out_channel % 4 == 0
+        leaky = 0
+        if (out_channel <= 64):
+            leaky = 0.1
+        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+        self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+        self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+        self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+        self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+    def forward(self, input):
+        conv3X3 = self.conv3X3(input)
+
+        conv5X5_1 = self.conv5X5_1(input)
+        conv5X5 = self.conv5X5_2(conv5X5_1)
+
+        conv7X7_2 = self.conv7X7_2(conv5X5_1)
+        conv7X7 = self.conv7x7_3(conv7X7_2)
+
+        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+        out = F.relu(out)
+        return out
+
+
+class FPN(nn.Module):
+
+    def __init__(self, in_channels_list, out_channels):
+        super(FPN, self).__init__()
+        leaky = 0
+        if (out_channels <= 64):
+            leaky = 0.1
+        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+    def forward(self, input):
+        # names = list(input.keys())
+        # input = list(input.values())
+
+        output1 = self.output1(input[0])
+        output2 = self.output2(input[1])
+        output3 = self.output3(input[2])
+
+        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
+        output2 = output2 + up3
+        output2 = self.merge2(output2)
+
+        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
+        output1 = output1 + up2
+        output1 = self.merge1(output1)
+
+        out = [output1, output2, output3]
+        return out
+
+
+class MobileNetV1(nn.Module):
+
+    def __init__(self):
+        super(MobileNetV1, self).__init__()
+        self.stage1 = nn.Sequential(
+            conv_bn(3, 8, 2, leaky=0.1),  # 3
+            conv_dw(8, 16, 1),  # 7
+            conv_dw(16, 32, 2),  # 11
+            conv_dw(32, 32, 1),  # 19
+            conv_dw(32, 64, 2),  # 27
+            conv_dw(64, 64, 1),  # 43
+        )
+        self.stage2 = nn.Sequential(
+            conv_dw(64, 128, 2),  # 43 + 16 = 59
+            conv_dw(128, 128, 1),  # 59 + 32 = 91
+            conv_dw(128, 128, 1),  # 91 + 32 = 123
+            conv_dw(128, 128, 1),  # 123 + 32 = 155
+            conv_dw(128, 128, 1),  # 155 + 32 = 187
+            conv_dw(128, 128, 1),  # 187 + 32 = 219
+        )
+        self.stage3 = nn.Sequential(
+            conv_dw(128, 256, 2),  # 219 + 32 = 251
+            conv_dw(256, 256, 1),  # 251 + 64 = 315
+        )
+        self.avg = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(256, 1000)
+
+    def forward(self, x):
+        x = self.stage1(x)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.avg(x)
+        # x = self.model(x)
+        x = x.view(-1, 256)
+        x = self.fc(x)
+        return x
+
+
+class ClassHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(ClassHead, self).__init__()
+        self.num_anchors = num_anchors
+        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(BboxHead, self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(LandmarkHead, self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+    classhead = nn.ModuleList()
+    for i in range(fpn_num):
+        classhead.append(ClassHead(inchannels, anchor_num))
+    return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+    bboxhead = nn.ModuleList()
+    for i in range(fpn_num):
+        bboxhead.append(BboxHead(inchannels, anchor_num))
+    return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+    landmarkhead = nn.ModuleList()
+    for i in range(fpn_num):
+        landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+    return landmarkhead
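+
+
+# Shape sketch (illustrative, not part of the original code): each head maps an
+# FPN/SSH feature of shape (B, C, H, W) to per-anchor predictions, e.g. for the
+# class heads built above:
+#
+#     heads = make_class_head(fpn_num=3, inchannels=64, anchor_num=2)
+#     # heads[i](feat) has shape (B, H*W*anchor_num, 2); BboxHead gives (..., 4)
+#     # and LandmarkHead gives (..., 10) per anchor.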
diff --git a/CodeFormer/facelib/detection/retinaface/retinaface_utils.py b/CodeFormer/facelib/detection/retinaface/retinaface_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c357757741c6d9bd7ce4d8ce740fefd51850fbf
--- /dev/null
+++ b/CodeFormer/facelib/detection/retinaface/retinaface_utils.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import torchvision
+from itertools import product as product
+from math import ceil
+
+
+class PriorBox(object):
+
+    def __init__(self, cfg, image_size=None, phase='train'):
+        super(PriorBox, self).__init__()
+        self.min_sizes = cfg['min_sizes']
+        self.steps = cfg['steps']
+        self.clip = cfg['clip']
+        self.image_size = image_size
+        self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+        self.name = 's'
+
+    def forward(self):
+        anchors = []
+        for k, f in enumerate(self.feature_maps):
+            min_sizes = self.min_sizes[k]
+            for i, j in product(range(f[0]), range(f[1])):
+                for min_size in min_sizes:
+                    s_kx = min_size / self.image_size[1]
+                    s_ky = min_size / self.image_size[0]
+                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+                    for cy, cx in product(dense_cy, dense_cx):
+                        anchors += [cx, cy, s_kx, s_ky]
+
+        # back to torch land
+        output = torch.Tensor(anchors).view(-1, 4)
+        if self.clip:
+            output.clamp_(max=1, min=0)
+        return output
+
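+# Usage sketch (illustrative, not part of the original code): with a RetinaFace
+# config dict (containing 'min_sizes', 'steps', 'clip'), anchors for an input of
+# size (H, W) are generated as an (N, 4) tensor of (cx, cy, s_kx, s_ky):
+#
+#     priors = PriorBox(cfg, image_size=(H, W)).forward()
+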
+
+def py_cpu_nms(dets, thresh):
+    """Pure Python NMS baseline."""
+    keep = torchvision.ops.nms(
+        boxes=torch.Tensor(dets[:, :4]),
+        scores=torch.Tensor(dets[:, 4]),
+        iou_threshold=thresh,
+    )
+
+    return list(keep)
+
+
+def point_form(boxes):
+    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+    representation for comparison to point form ground truth data.
+    Args:
+        boxes: (tensor) center-size default boxes from priorbox layers.
+    Return:
+        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+    """
+    return torch.cat(
+        (
+            boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin
+            boxes[:, :2] + boxes[:, 2:] / 2),
+        1)  # xmax, ymax
+
+
+def center_size(boxes):
+    """ Convert point_form boxes to (cx, cy, w, h)
+    representation for comparison to center-size form ground truth data.
+    Args:
+        boxes: (tensor) point_form boxes
+    Return:
+        boxes: (tensor) Converted cx, cy, w, h form of boxes.
+    """
+    return torch.cat(
+        (
+            (boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
+            boxes[:, 2:] - boxes[:, :2]),  # w, h
+        1)
+
+
+def intersect(box_a, box_b):
+    """ We resize both tensors to [A,B,2] without new malloc:
+    [A,2] -> [A,1,2] -> [A,B,2]
+    [B,2] -> [1,B,2] -> [A,B,2]
+    Then we compute the area of intersect between box_a and box_b.
+    Args:
+      box_a: (tensor) bounding boxes, Shape: [A,4].
+      box_b: (tensor) bounding boxes, Shape: [B,4].
+    Return:
+      (tensor) intersection area, Shape: [A,B].
+    """
+    A = box_a.size(0)
+    B = box_b.size(0)
+    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+    inter = torch.clamp((max_xy - min_xy), min=0)
+    return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
+    is simply the intersection over union of two boxes.  Here we operate on
+    ground truth boxes and default boxes.
+    E.g.:
+        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+    Args:
+        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+    Return:
+        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+    """
+    inter = intersect(box_a, box_b)
+    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
+    area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
+    union = area_a + area_b - inter
+    return inter / union  # [A,B]
+
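+# Example sketch (illustrative, not part of the original code): IoU between two
+# boxes in (xmin, ymin, xmax, ymax) form; a box against itself gives IoU 1.0:
+#
+#     a = torch.tensor([[0., 0., 10., 10.]])
+#     b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
+#     jaccard(a, b)  # tensor([[1.0000, 0.1429]])
+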
+
+def matrix_iou(a, b):
+    """
+    return IoU of a and b, numpy version for data augmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+    """
+    return IoF of a and b, numpy version for data augmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+    """Match each prior box with the ground truth box of the highest jaccard
+    overlap, encode the bounding boxes, then return the matched indices
+    corresponding to both confidence and location preds.
+    Args:
+        threshold: (float) The overlap threshold used when matching boxes.
+        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+        variances: (tensor) Variances corresponding to each prior coord,
+            Shape: [num_priors, 4].
+        labels: (tensor) All the class labels for the image, Shape: [num_obj].
+        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+        idx: (int) current batch index
+    Return:
+        The matched indices corresponding to 1)location 2)confidence
+        3)landm preds.
+    """
+    # jaccard index
+    overlaps = jaccard(truths, point_form(priors))
+    # (Bipartite Matching)
+    # [1,num_objects] best prior for each ground truth
+    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+    # ignore hard gt
+    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+    if best_prior_idx_filter.shape[0] <= 0:
+        loc_t[idx] = 0
+        conf_t[idx] = 0
+        return
+
+    # [1,num_priors] best ground truth for each prior
+    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+    best_truth_idx.squeeze_(0)
+    best_truth_overlap.squeeze_(0)
+    best_prior_idx.squeeze_(1)
+    best_prior_idx_filter.squeeze_(1)
+    best_prior_overlap.squeeze_(1)
+    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
+    # TODO refactor: index  best_prior_idx with long tensor
+    # ensure every gt matches with its prior of max overlap
+    for j in range(best_prior_idx.size(0)):  # assign each ground-truth box to its best-matching prior
+        best_truth_idx[best_prior_idx[j]] = j
+    matches = truths[best_truth_idx]  # Shape: [num_priors,4]  gt box matched to each prior
+    conf = labels[best_truth_idx]  # Shape: [num_priors]  gt label matched to each prior
+    conf[best_truth_overlap < threshold] = 0  # label as background: priors below the overlap threshold become negatives
+    loc = encode(matches, priors, variances)
+
+    matches_landm = landms[best_truth_idx]
+    landm = encode_landm(matches_landm, priors, variances)
+    loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn
+    conf_t[idx] = conf  # [num_priors] top class label for each prior
+    landm_t[idx] = landm
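+
+# Illustrative call from a multibox-style loss (shapes only; names are placeholders):
+#   loc_t   = torch.empty(batch, num_priors, 4)
+#   conf_t  = torch.empty(batch, num_priors, dtype=torch.long)
+#   landm_t = torch.empty(batch, num_priors, 10)
+#   for idx in range(batch):
+#       match(0.35, truths[idx], priors, [0.1, 0.2], labels[idx], landms[idx],
+#             loc_t, conf_t, landm_t, idx)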
+
+
+def encode(matched, priors, variances):
+    """Encode the variances from the priorbox layers into the ground truth boxes
+    we have matched (based on jaccard overlap) with the prior boxes.
+    Args:
+        matched: (tensor) Coords of ground truth for each prior in point-form
+            Shape: [num_priors, 4].
+        priors: (tensor) Prior boxes in center-offset form
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        encoded boxes (tensor), Shape: [num_priors, 4]
+    """
+
+    # dist b/t match center and prior's center
+    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+    # encode variance
+    g_cxcy /= (variances[0] * priors[:, 2:])
+    # match wh / prior wh
+    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+    g_wh = torch.log(g_wh) / variances[1]
+    # return target for smooth_l1_loss
+    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+    """Encode the variances from the priorbox layers into the ground truth boxes
+    we have matched (based on jaccard overlap) with the prior boxes.
+    Args:
+        matched: (tensor) Coords of ground truth for each prior in point-form
+            Shape: [num_priors, 10].
+        priors: (tensor) Prior boxes in center-offset form
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        encoded landm (tensor), Shape: [num_priors, 10]
+    """
+
+    # dist b/t match center and prior's center
+    matched = torch.reshape(matched, (matched.size(0), 5, 2))
+    priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+    g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+    # encode variance
+    g_cxcy /= (variances[0] * priors[:, :, 2:])
+    # g_cxcy /= priors[:, :, 2:]
+    g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+    # return target for smooth_l1_loss
+    return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+    """Decode locations from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        loc (tensor): location predictions for loc layers,
+            Shape: [num_priors,4]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded bounding box predictions
+    """
+
+    boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+                       priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+    boxes[:, :2] -= boxes[:, 2:] / 2
+    boxes[:, 2:] += boxes[:, :2]
+    return boxes
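+
+# Note: decode() undoes encode() above; decode(encode(m, p, v), p, v) recovers the
+# matched point-form boxes m up to floating-point error.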
+
+
+def decode_landm(pre, priors, variances):
+    """Decode landm from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        pre (tensor): landm predictions for loc layers,
+            Shape: [num_priors,10]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded landm predictions
+    """
+    tmp = (
+        priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+    )
+    landms = torch.cat(tmp, dim=1)
+    return landms
+
+
+def batched_decode(b_loc, priors, variances):
+    """Decode locations from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        b_loc (tensor): location predictions for loc layers,
+            Shape: [num_batches,num_priors,4]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [1,num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded bounding box predictions
+    """
+    boxes = (
+        priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+        priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+    )
+    boxes = torch.cat(boxes, dim=2)
+
+    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+    boxes[:, :, 2:] += boxes[:, :, :2]
+    return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+    """Decode landm from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        pre (tensor): landm predictions for loc layers,
+            Shape: [num_batches,num_priors,10]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [1,num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded landm predictions
+    """
+    landms = (
+        priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+    )
+    landms = torch.cat(landms, dim=2)
+    return landms
+
+
+def log_sum_exp(x):
+    """Utility function for computing log_sum_exp while determining
+    This will be used to determine unaveraged confidence loss across
+    all examples in a batch.
+    Args:
+        x (Variable(tensor)): conf_preds from conf layers
+    """
+    x_max = x.data.max()
+    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
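+
+# Subtracting x_max before exponentiating keeps exp() from overflowing; the identity
+# log(sum(exp(x - m))) + m == log(sum(exp(x))) means the result is unchanged.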
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+    """Apply non-maximum suppression at test time to avoid detecting too many
+    overlapping bounding boxes for a given object.
+    Args:
+        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+        scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
+        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+        top_k: (int) The Maximum number of box preds to consider.
+    Return:
+        The indices of the kept boxes with respect to num_priors.
+    """
+
+    keep = torch.Tensor(scores.size(0)).fill_(0).long()
+    if boxes.numel() == 0:
+        return keep
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+    area = torch.mul(x2 - x1, y2 - y1)
+    v, idx = scores.sort(0)  # sort in ascending order
+    # I = I[v >= 0.01]
+    idx = idx[-top_k:]  # indices of the top-k largest vals
+    xx1 = boxes.new()
+    yy1 = boxes.new()
+    xx2 = boxes.new()
+    yy2 = boxes.new()
+    w = boxes.new()
+    h = boxes.new()
+
+    # keep = torch.Tensor()
+    count = 0
+    while idx.numel() > 0:
+        i = idx[-1]  # index of current largest val
+        # keep.append(i)
+        keep[count] = i
+        count += 1
+        if idx.size(0) == 1:
+            break
+        idx = idx[:-1]  # remove kept element from view
+        # load bboxes of next highest vals
+        torch.index_select(x1, 0, idx, out=xx1)
+        torch.index_select(y1, 0, idx, out=yy1)
+        torch.index_select(x2, 0, idx, out=xx2)
+        torch.index_select(y2, 0, idx, out=yy2)
+        # store element-wise max with next highest score
+        xx1 = torch.clamp(xx1, min=x1[i])
+        yy1 = torch.clamp(yy1, min=y1[i])
+        xx2 = torch.clamp(xx2, max=x2[i])
+        yy2 = torch.clamp(yy2, max=y2[i])
+        w.resize_as_(xx2)
+        h.resize_as_(yy2)
+        w = xx2 - xx1
+        h = yy2 - yy1
+        # check sizes of xx1 and xx2.. after each iteration
+        w = torch.clamp(w, min=0.0)
+        h = torch.clamp(h, min=0.0)
+        inter = w * h
+        # IoU = i / (area(a) + area(b) - i)
+        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
+        union = (rem_areas - inter) + area[i]
+        IoU = inter / union  # store result in iou
+        # keep only elements with an IoU <= overlap
+        idx = idx[IoU.le(overlap)]
+    return keep, count
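+
+# Illustrative usage (assumes boxes and scores are tensors of matching length):
+#   keep, count = nms(boxes, scores, overlap=0.4, top_k=200)
+#   kept_boxes = boxes[keep[:count]]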
diff --git a/CodeFormer/facelib/detection/yolov5face/__init__.py b/CodeFormer/facelib/detection/yolov5face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CodeFormer/facelib/detection/yolov5face/face_detector.py b/CodeFormer/facelib/detection/yolov5face/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..0103411e27860898fee470895a7cf59d8be2e11a
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/face_detector.py
@@ -0,0 +1,142 @@
+import copy
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+from facelib.detection.yolov5face.models.yolo import Model
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+    check_img_size,
+    non_max_suppression_face,
+    scale_coords,
+    scale_coords_landmarks,
+)
+
+IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:3])) >= (1, 9, 0)
+
+
+def isListempty(inList):
+    if isinstance(inList, list): # Is a list
+        return all(map(isListempty, inList))
+    return False # Not a list
+
+class YoloDetector:
+    def __init__(
+        self,
+        config_name,
+        min_face=10,
+        target_size=None,
+        device='cuda',
+    ):
+        """
+        config_name: name of .yaml config with network configuration from models/ folder.
+        min_face : minimal face size in pixels.
+        target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
+                    None for original resolution.
+        """
+        self._class_path = Path(__file__).parent.absolute()
+        self.target_size = target_size
+        self.min_face = min_face
+        self.detector = Model(cfg=config_name)
+        self.device = device
+
+
+    def _preprocess(self, imgs):
+        """
+        Preprocessing image before passing through the network. Resize and conversion to torch tensor.
+        """
+        pp_imgs = []
+        for img in imgs:
+            h0, w0 = img.shape[:2]  # orig hw
+            if self.target_size:
+                r = self.target_size / min(h0, w0)  # resize image to img_size
+                if r < 1:
+                    img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+            imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max())  # check img_size
+            img = letterbox(img, new_shape=imgsz)[0]
+            pp_imgs.append(img)
+        pp_imgs = np.array(pp_imgs)
+        pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
+        pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
+        pp_imgs = pp_imgs.float()  # uint8 to fp16/32
+        return pp_imgs / 255.0  # 0 - 255 to 0.0 - 1.0
+
+    def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+        """
+        Postprocessing of raw pytorch model output.
+        Returns:
+            bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+        """
+        bboxes = [[] for _ in range(len(origimgs))]
+        landmarks = [[] for _ in range(len(origimgs))]
+
+        pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+        for image_id, origimg in enumerate(origimgs):
+            img_shape = origimg.shape
+            image_height, image_width = img_shape[:2]
+            gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]]  # normalization gain landmarks
+            det = pred[image_id].cpu()
+            scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
+            scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
+
+            for j in range(det.size()[0]):
+                box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+                box = list(
+                    map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
+                )
+                if box[3] - box[1] < self.min_face:
+                    continue
+                lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+                lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
+                lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
+                bboxes[image_id].append(box)
+                landmarks[image_id].append(lm)
+        return bboxes, landmarks
+
+    def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+        """
+        Get bbox coordinates and keypoints of faces on original image.
+        Params:
+            imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
+            conf_thres: confidence threshold for each prediction
+            iou_thres: threshold for NMS (filter of intersecting bboxes)
+        Returns:
+            bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+        """
+        # Pass input images through face detector
+        images = imgs if isinstance(imgs, list) else [imgs]
+        images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+        origimgs = copy.deepcopy(images)
+
+        images = self._preprocess(images)
+        
+        if IS_HIGH_VERSION:
+            with torch.inference_mode():  # for pytorch>=1.9 
+                pred = self.detector(images)[0]
+        else:
+            with torch.no_grad():  # for pytorch<1.9
+                pred = self.detector(images)[0]
+
+        bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+        # return bboxes, points
+        if not isListempty(points):
+            bboxes = np.array(bboxes).reshape(-1,4)
+            points = np.array(points).reshape(-1,10)
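+            # `padding` appears to be a placeholder column (a copy of x1) so the output
+            # matches the 15-column bbox+score+landmarks layout of the other facelib detectors.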
+            padding = bboxes[:,0].reshape(-1,1)
+            return np.concatenate((bboxes, padding, points), axis=1)
+        else:
+            return None
+
+    def __call__(self, *args):
+        return self.detect_faces(*args)
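+
+# Illustrative usage sketch (assumes the network weights have already been loaded
+# into `detector.detector` elsewhere, as facelib does when building the detector):
+#   detector = YoloDetector('facelib/detection/yolov5face/models/yolov5n.yaml', device='cuda')
+#   dets = detector.detect_faces(cv2.imread('face.jpg'))  # BGR input
+#   # dets is an [N, 15] array (x1, y1, x2, y2, padding column, 10 landmark coords) or None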
diff --git a/CodeFormer/facelib/detection/yolov5face/models/__init__.py b/CodeFormer/facelib/detection/yolov5face/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CodeFormer/facelib/detection/yolov5face/models/common.py b/CodeFormer/facelib/detection/yolov5face/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..497a00444c4c59725001993a63fe4617e9d323c8
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/models/common.py
@@ -0,0 +1,299 @@
+# This file contains modules common to various models
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+    make_divisible,
+    non_max_suppression,
+    scale_coords,
+    xyxy2xywh,
+)
+
+
+def autopad(k, p=None):  # kernel, padding
+    # Pad to 'same'
+    if p is None:
+        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
+    return p
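+# For example, autopad(3) returns 1 and autopad((3, 5)) returns [1, 2] ('same' padding at stride 1).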
+
+
+def channel_shuffle(x, groups):
+    batchsize, num_channels, height, width = x.data.size()
+    channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc")
+
+    # reshape
+    x = x.view(batchsize, groups, channels_per_group, height, width)
+    x = torch.transpose(x, 1, 2).contiguous()
+
+    # flatten
+    return x.view(batchsize, -1, height, width)
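+# For example, with groups=2 the channel order [a0, a1, b0, b1] becomes [a0, b0, a1, b1],
+# mixing information between the two branches of a ShuffleV2Block.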
+
+
+def DWConv(c1, c2, k=1, s=1, act=True):
+    # Depthwise convolution
+    return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class Conv(nn.Module):
+    # Standard convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class StemBlock(nn.Module):
+    def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
+        super().__init__()
+        self.stem_1 = Conv(c1, c2, k, s, p, g, act)
+        self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0)
+        self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1)
+        self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
+        self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0)
+
+    def forward(self, x):
+        stem_1_out = self.stem_1(x)
+        stem_2a_out = self.stem_2a(stem_1_out)
+        stem_2b_out = self.stem_2b(stem_2a_out)
+        stem_2p_out = self.stem_2p(stem_1_out)
+        return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1))
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_, c2, 3, 1, g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
+        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
+        self.act = nn.LeakyReLU(0.1, inplace=True)
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+    def forward(self, x):
+        y1 = self.cv3(self.m(self.cv1(x)))
+        y2 = self.cv2(x)
+        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
+class C3(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+    def forward(self, x):
+        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class ShuffleV2Block(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super().__init__()
+
+        if not 1 <= stride <= 3:
+            raise ValueError("illegal stride value")
+        self.stride = stride
+
+        branch_features = oup // 2
+
+        if self.stride > 1:
+            self.branch1 = nn.Sequential(
+                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+                nn.BatchNorm2d(inp),
+                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+                nn.BatchNorm2d(branch_features),
+                nn.SiLU(),
+            )
+        else:
+            self.branch1 = nn.Sequential()
+
+        self.branch2 = nn.Sequential(
+            nn.Conv2d(
+                inp if (self.stride > 1) else branch_features,
+                branch_features,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            ),
+            nn.BatchNorm2d(branch_features),
+            nn.SiLU(),
+            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+            nn.BatchNorm2d(branch_features),
+            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.BatchNorm2d(branch_features),
+            nn.SiLU(),
+        )
+
+    @staticmethod
+    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+    def forward(self, x):
+        if self.stride == 1:
+            x1, x2 = x.chunk(2, dim=1)
+            out = torch.cat((x1, self.branch2(x2)), dim=1)
+        else:
+            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+        out = channel_shuffle(out, 2)
+        return out
+
+
+class SPP(nn.Module):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, c1, c2, k=(5, 9, 13)):
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+    def forward(self, x):
+        x = self.cv1(x)
+        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class Focus(nn.Module):
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+
+    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+
+
+class Concat(nn.Module):
+    # Concatenate a list of tensors along dimension
+    def __init__(self, dimension=1):
+        super().__init__()
+        self.d = dimension
+
+    def forward(self, x):
+        return torch.cat(x, self.d)
+
+
+class NMS(nn.Module):
+    # Non-Maximum Suppression (NMS) module
+    conf = 0.25  # confidence threshold
+    iou = 0.45  # IoU threshold
+    classes = None  # (optional list) filter by class
+
+    def forward(self, x):
+        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+
+
+class AutoShape(nn.Module):
+    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    img_size = 640  # inference size (pixels)
+    conf = 0.25  # NMS confidence threshold
+    iou = 0.45  # NMS IoU threshold
+    classes = None  # (optional list) filter by class
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model.eval()
+
+    def autoshape(self):
+        print("autoShape already enabled, skipping... ")  # model already converted to model.autoshape()
+        return self
+
+    def forward(self, imgs, size=640, augment=False, profile=False):
+        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
+        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
+        #   numpy:           = np.zeros((720,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
+        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+
+        p = next(self.model.parameters())  # for device and type
+        if isinstance(imgs, torch.Tensor):  # torch
+            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
+
+        # Pre-process
+        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
+        shape0, shape1 = [], []  # image and inference shapes
+        for i, im in enumerate(imgs):
+            im = np.array(im)  # to numpy
+            if im.shape[0] < 5:  # image in CHW
+                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
+            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
+            s = im.shape[:2]  # HWC
+            shape0.append(s)  # image shape
+            g = size / max(s)  # gain
+            shape1.append([y * g for y in s])
+            imgs[i] = im  # update
+        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
+        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
+        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0  # uint8 to fp16/32
+
+        # Inference
+        with torch.no_grad():
+            y = self.model(x, augment, profile)[0]  # forward
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+
+        # Post-process
+        for i in range(n):
+            scale_coords(shape1, y[i][:, :4], shape0[i])
+
+        return Detections(imgs, y, self.names)
+
+
+class Detections:
+    # detections class for YOLOv5 inference results
+    def __init__(self, imgs, pred, names=None):
+        super().__init__()
+        d = pred[0].device  # device
+        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs]  # normalizations
+        self.imgs = imgs  # list of images as numpy arrays
+        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
+        self.names = names  # class names
+        self.xyxy = pred  # xyxy pixels
+        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
+        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
+        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
+        self.n = len(self.pred)
+
+    def __len__(self):
+        return self.n
+
+    def tolist(self):
+        # return a list of Detections objects, i.e. 'for result in results.tolist():'
+        x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+        for d in x:
+            for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]:
+                setattr(d, k, getattr(d, k)[0])  # pop out of list
+        return x
diff --git a/CodeFormer/facelib/detection/yolov5face/models/experimental.py b/CodeFormer/facelib/detection/yolov5face/models/experimental.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ba4c4420789c92dc0e2aaeb3d5b64859ec728c
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/models/experimental.py
@@ -0,0 +1,45 @@
+# This file contains experimental modules
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+
+
+class CrossConv(nn.Module):
+    # Cross Convolution Downsample
+    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, (1, k), (1, s))
+        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class MixConv2d(nn.Module):
+    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+        super().__init__()
+        groups = len(k)
+        if equal_ch:  # equal c_ per group
+            i = torch.linspace(0, groups - 1e-6, c2).floor()  # c2 indices
+            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
+        else:  # equal weight.numel() per group
+            b = [c2] + [0] * groups
+            a = np.eye(groups + 1, groups, k=-1)
+            a -= np.roll(a, 1, axis=1)
+            a *= np.array(k) ** 2
+            a[0] = 1
+            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
+
+        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.LeakyReLU(0.1, inplace=True)
+
+    def forward(self, x):
+        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
diff --git a/CodeFormer/facelib/detection/yolov5face/models/yolo.py b/CodeFormer/facelib/detection/yolov5face/models/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..70845d972f0bcfd3632fcbac096b23e1b4d4d779
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/models/yolo.py
@@ -0,0 +1,235 @@
+import math
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+import yaml  # for torch hub
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import (
+    C3,
+    NMS,
+    SPP,
+    AutoShape,
+    Bottleneck,
+    BottleneckCSP,
+    Concat,
+    Conv,
+    DWConv,
+    Focus,
+    ShuffleV2Block,
+    StemBlock,
+)
+from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
+from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
+from facelib.detection.yolov5face.utils.general import make_divisible
+from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
+
+
+class Detect(nn.Module):
+    stride = None  # strides computed during build
+    export = False  # onnx export
+
+    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+        super().__init__()
+        self.nc = nc  # number of classes
+        self.no = nc + 5 + 10  # number of outputs per anchor
+
+        self.nl = len(anchors)  # number of detection layers
+        self.na = len(anchors[0]) // 2  # number of anchors
+        self.grid = [torch.zeros(1)] * self.nl  # init grid
+        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+        self.register_buffer("anchors", a)  # shape(nl,na,2)
+        self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+
+    def forward(self, x):
+        z = []  # inference output
+        if self.export:
+            for i in range(self.nl):
+                x[i] = self.m[i](x[i])
+            return x
+        for i in range(self.nl):
+            x[i] = self.m[i](x[i])  # conv
+            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+            if not self.training:  # inference
+                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                y = torch.full_like(x[i], 0)
+                y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
+                y[..., 5:15] = x[i][..., 5:15]
+
+                y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
+                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+
+                y[..., 5:7] = (
+                    y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+                )  # landmark x1 y1
+                y[..., 7:9] = (
+                    y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+                )  # landmark x2 y2
+                y[..., 9:11] = (
+                    y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+                )  # landmark x3 y3
+                y[..., 11:13] = (
+                    y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+                )  # landmark x4 y4
+                y[..., 13:15] = (
+                    y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+                )  # landmark x5 y5
+
+                z.append(y.view(bs, -1, self.no))
+
+        return x if self.training else (torch.cat(z, 1), x)
+
+    @staticmethod
+    def _make_grid(nx=20, ny=20):
+        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
+        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+
+
+class Model(nn.Module):
+    def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None):  # model, input channels, number of classes
+        super().__init__()
+        self.yaml_file = Path(cfg).name
+        with Path(cfg).open(encoding="utf8") as f:
+            self.yaml = yaml.safe_load(f)  # model dict
+
+        # Define model
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
+            self.yaml["nc"] = nc  # override yaml value
+
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
+        self.names = [str(i) for i in range(self.yaml["nc"])]  # default names
+
+        # Build strides, anchors
+        m = self.model[-1]  # Detect()
+        if isinstance(m, Detect):
+            s = 128  # 2x min stride
+            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+            m.anchors /= m.stride.view(-1, 1, 1)
+            check_anchor_order(m)
+            self.stride = m.stride
+            self._initialize_biases()  # only run once
+
+    def forward(self, x):
+        return self.forward_once(x)  # single-scale inference, train
+
+    def forward_once(self, x):
+        y = []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+
+            x = m(x)  # run
+            y.append(x if m.i in self.save else None)  # save output
+
+        return x
+
+    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+        # https://arxiv.org/abs/1708.02002 section 3.3
+        m = self.model[-1]  # Detect() module
+        for mi, s in zip(m.m, m.stride):  # from
+            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
+            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def _print_biases(self):
+        m = self.model[-1]  # Detect() module
+        for mi in m.m:  # from
+            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
+            print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
+        print("Fusing layers... ")
+        for m in self.model.modules():
+            if isinstance(m, Conv) and hasattr(m, "bn"):
+                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                delattr(m, "bn")  # remove batchnorm
+                m.forward = m.fuseforward  # update forward
+            elif type(m) is nn.Upsample:
+                m.recompute_scale_factor = None  # torch 1.11.0 compatibility
+        return self
+
+    def nms(self, mode=True):  # add or remove NMS module
+        present = isinstance(self.model[-1], NMS)  # last layer is NMS
+        if mode and not present:
+            print("Adding NMS... ")
+            m = NMS()  # module
+            m.f = -1  # from
+            m.i = self.model[-1].i + 1  # index
+            self.model.add_module(name=str(m.i), module=m)  # add
+            self.eval()
+        elif not mode and present:
+            print("Removing NMS... ")
+            self.model = self.model[:-1]  # remove
+        return self
+
+    def autoshape(self):  # add autoShape module
+        print("Adding autoShape... ")
+        m = AutoShape(self)  # wrap model
+        copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=())  # copy attributes
+        return m
+
+
+def parse_model(d, ch):  # model_dict, input_channels(3)
+    anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
+    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
+    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
+
+    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
+    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
+        m = eval(m) if isinstance(m, str) else m  # eval strings
+        for j, a in enumerate(args):
+            try:
+                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
+            except:
+                pass
+
+        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
+        if m in [
+            Conv,
+            Bottleneck,
+            SPP,
+            DWConv,
+            MixConv2d,
+            Focus,
+            CrossConv,
+            BottleneckCSP,
+            C3,
+            ShuffleV2Block,
+            StemBlock,
+        ]:
+            c1, c2 = ch[f], args[0]
+
+            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+            args = [c1, c2, *args[1:]]
+            if m in [BottleneckCSP, C3]:
+                args.insert(2, n)
+                n = 1
+        elif m is nn.BatchNorm2d:
+            args = [ch[f]]
+        elif m is Concat:
+            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
+        elif m is Detect:
+            args.append([ch[x + 1] for x in f])
+            if isinstance(args[1], int):  # number of anchors
+                args[1] = [list(range(args[1] * 2))] * len(f)
+        else:
+            c2 = ch[f]
+
+        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
+        t = str(m)[8:-2].replace("__main__.", "")  # module type
+        np = sum(x.numel() for x in m_.parameters())  # number params
+        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
+        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
+        layers.append(m_)
+        ch.append(c2)
+    return nn.Sequential(*layers), sorted(save)
diff --git a/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml b/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0532b0e22fa7f59349b178146ffddcfdb368aba6
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml
@@ -0,0 +1,47 @@
+# parameters
+nc: 1  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors:
+  - [4,5,  8,10,  13,16]  # P3/8
+  - [23,29,  43,55,  73,105]  # P4/16
+  - [146,217,  231,300,  335,433]  # P5/32
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, StemBlock, [64, 3, 2]],  # 0-P1/2
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],      # 2-P3/8
+   [-1, 9, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],      # 4-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],     # 6-P5/32
+   [-1, 1, SPP, [1024, [3,5,7]]],
+   [-1, 3, C3, [1024, False]],      # 8
+  ]
+
+# YOLOv5 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 5], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 12
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 3], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 16 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 13], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 19 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 9], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 22 (P5/32-large)
+
+   [[16, 19, 22], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]
\ No newline at end of file
diff --git a/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml b/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..caba6bed674aa2213b110f19e04eb352ffbeaf1e
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml
@@ -0,0 +1,45 @@
+# parameters
+nc: 1  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors:
+  - [4,5,  8,10,  13,16]  # P3/8
+  - [23,29,  43,55,  73,105]  # P4/16
+  - [146,217,  231,300,  335,433]  # P5/32
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, StemBlock, [32, 3, 2]],    # 0-P2/4
+   [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
+   [-1, 3, ShuffleV2Block, [128, 1]], # 2
+   [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
+   [-1, 7, ShuffleV2Block, [256, 1]], # 4
+   [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
+   [-1, 3, ShuffleV2Block, [512, 1]], # 6
+  ]
+
+# YOLOv5 head
+head:
+  [[-1, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, C3, [128, False]],  # 10
+
+   [-1, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 2], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, C3, [128, False]],  # 14 (P3/8-small)
+
+   [-1, 1, Conv, [128, 3, 2]],
+   [[-1, 11], 1, Concat, [1]],  # cat head P4
+   [-1, 1, C3, [128, False]],  # 17 (P4/16-medium)
+
+   [-1, 1, Conv, [128, 3, 2]],
+   [[-1, 7], 1, Concat, [1]],  # cat head P5
+   [-1, 1, C3, [128, False]],  # 20 (P5/32-large)
+
+   [[14, 17, 20], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/__init__.py b/CodeFormer/facelib/detection/yolov5face/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py b/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4eba3e94888709be7d2a7c7499fbcc1808b4a88
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py
@@ -0,0 +1,12 @@
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
+    da = a[-1] - a[0]  # delta a
+    ds = m.stride[-1] - m.stride[0]  # delta s
+    if da.sign() != ds.sign():  # anchor order does not match stride order
+        print("Reversing anchor order")
+        m.anchors[:] = m.anchors.flip(0)
+        m.anchor_grid[:] = m.anchor_grid.flip(0)
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/datasets.py b/CodeFormer/facelib/detection/yolov5face/utils/datasets.py
new file mode 100755
index 0000000000000000000000000000000000000000..e672b136f56fd6b05038e24377908361a54fe519
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/utils/datasets.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
+    shape = img.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+    if not scaleup:  # only scale down, do not scale up (for better test mAP)
+        r = min(r, 1.0)
+
+    # Compute padding
+    ratio = r, r  # width, height ratios
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+    if auto:  # minimum rectangle
+        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
+    elif scale_fill:  # stretch
+        dw, dh = 0.0, 0.0
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+    return img, ratio, (dw, dh)
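+
+# Illustrative example: letterbox(np.zeros((720, 1280, 3), np.uint8), 640, auto=False)
+# scales the image to 360x640 and pads 140 px of (114, 114, 114) on the top and bottom,
+# returning a 640x640 image together with the scale ratio and the (dw, dh) padding.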
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py b/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8b631348f2d0cdea4e5a3594bb59f3e8f34a0f
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py
@@ -0,0 +1,5 @@
+import torch
+import sys
+sys.path.insert(0,'./facelib/detection/yolov5face')
+model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
+torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth')
\ No newline at end of file
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/general.py b/CodeFormer/facelib/detection/yolov5face/utils/general.py
new file mode 100755
index 0000000000000000000000000000000000000000..1c8e14f56a107ec3a4269c382cfc5168ad780ffc
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/utils/general.py
@@ -0,0 +1,271 @@
+import math
+import time
+
+import numpy as np
+import torch
+import torchvision
+
+
+def check_img_size(img_size, s=32):
+    # Verify img_size is a multiple of stride s
+    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
+    # if new_size != img_size:
+    #     print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
+    return new_size
+
+
+def make_divisible(x, divisor):
+    # Returns x evenly divisible by divisor
+    return math.ceil(x / divisor) * divisor
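+# For example, make_divisible(638, 32) returns 640 (rounds up to the nearest multiple).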
+
+
+def xyxy2xywh(x):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+    y[:, 2] = x[:, 2] - x[:, 0]  # width
+    y[:, 3] = x[:, 3] - x[:, 1]  # height
+    return y
+
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    return y
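+
+# Note: xyxy2xywh() and xywh2xyxy() are exact inverses, so converting a box back and
+# forth returns the original coordinates (up to floating-point error).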
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+    # Rescale coords (xyxy) from img1_shape to img0_shape
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
+        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    coords[:, [0, 2]] -= pad[0]  # x padding
+    coords[:, [1, 3]] -= pad[1]  # y padding
+    coords[:, :4] /= gain
+    clip_coords(coords, img0_shape)
+    return coords
+
+
+def clip_coords(boxes, img_shape):
+    # Clip xyxy bounding boxes to the image shape (height, width)
+    boxes[:, 0].clamp_(0, img_shape[1])  # x1
+    boxes[:, 1].clamp_(0, img_shape[0])  # y1
+    boxes[:, 2].clamp_(0, img_shape[1])  # x2
+    boxes[:, 3].clamp_(0, img_shape[0])  # y2
+
+
+def box_iou(box1, box2):
+    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+    """
+    Return intersection-over-union (Jaccard index) of boxes.
+    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+    Arguments:
+        box1 (Tensor[N, 4])
+        box2 (Tensor[M, 4])
+    Returns:
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise
+            IoU values for every element in boxes1 and boxes2
+    """
+
+    def box_area(box):
+        return (box[2] - box[0]) * (box[3] - box[1])
+
+    area1 = box_area(box1.T)
+    area2 = box_area(box2.T)
+
+    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+    return inter / (area1[:, None] + area2 - inter)
+
+
+def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+    """Performs Non-Maximum Suppression (NMS) on inference results
+    Returns:
+         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+    """
+
+    nc = prediction.shape[2] - 15  # number of classes
+    xc = prediction[..., 4] > conf_thres  # candidates
+
+    # Settings
+    # (pixels) maximum box width and height
+    max_wh = 4096
+    time_limit = 10.0  # seconds to quit after
+    redundant = True  # require redundant detections
+    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
+    merge = False  # use merge-NMS
+
+    t = time.time()
+    output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+    for xi, x in enumerate(prediction):  # image index, image inference
+        # Apply constraints
+        x = x[xc[xi]]  # confidence
+
+        # Cat apriori labels if autolabelling
+        if labels and len(labels[xi]):
+            label = labels[xi]
+            v = torch.zeros((len(label), nc + 15), device=x.device)
+            v[:, :4] = label[:, 1:5]  # box
+            v[:, 4] = 1.0  # conf
+            v[range(len(label)), label[:, 0].long() + 15] = 1.0  # cls
+            x = torch.cat((x, v), 0)
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        x[:, 15:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+        box = xywh2xyxy(x[:, :4])
+
+        # Detections matrix nx6 (xyxy, conf, landmarks, cls)
+        if multi_label:
+            i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
+        else:  # best class only
+            conf, j = x[:, 15:].max(1, keepdim=True)
+            x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+        # If none remain process next image
+        n = x.shape[0]  # number of boxes
+        if not n:
+            continue
+
+        # Batched NMS
+        c = x[:, 15:16] * (0 if agnostic else max_wh)  # classes
+        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+
+        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if (time.time() - t) > time_limit:
+            break  # time limit exceeded
+
+    return output
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+    """Performs Non-Maximum Suppression (NMS) on inference results
+
+    Returns:
+         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+    """
+
+    nc = prediction.shape[2] - 5  # number of classes
+    xc = prediction[..., 4] > conf_thres  # candidates
+
+    # Settings
+    # (pixels) maximum box width and height
+    max_wh = 4096
+    time_limit = 10.0  # seconds to quit after
+    redundant = True  # require redundant detections
+    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
+    merge = False  # use merge-NMS
+
+    t = time.time()
+    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+    for xi, x in enumerate(prediction):  # image index, image inference
+        x = x[xc[xi]]  # confidence
+
+        # Cat apriori labels if autolabelling
+        if labels and len(labels[xi]):
+            label_id = labels[xi]
+            v = torch.zeros((len(label_id), nc + 5), device=x.device)
+            v[:, :4] = label_id[:, 1:5]  # box
+            v[:, 4] = 1.0  # conf
+            v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0  # cls
+            x = torch.cat((x, v), 0)
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+        box = xywh2xyxy(x[:, :4])
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        if multi_label:
+            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+        else:  # best class only
+            conf, j = x[:, 5:].max(1, keepdim=True)
+            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+
+        x = x[x[:, 4].argsort(descending=True)]  # sort by confidence
+
+        # Batched NMS
+        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if (time.time() - t) > time_limit:
+            print(f"WARNING: NMS time limit {time_limit}s exceeded")
+            break  # time limit exceeded
+
+    return output
+
+
+def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+    # Rescale 5-point landmark coords (x1, y1, ..., x5, y5) from img1_shape to img0_shape
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
+        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    coords[:, [0, 2, 4, 6, 8]] -= pad[0]  # x padding
+    coords[:, [1, 3, 5, 7, 9]] -= pad[1]  # y padding
+    coords[:, :10] /= gain
+    coords[:, 0].clamp_(0, img0_shape[1])  # x1
+    coords[:, 1].clamp_(0, img0_shape[0])  # y1
+    coords[:, 2].clamp_(0, img0_shape[1])  # x2
+    coords[:, 3].clamp_(0, img0_shape[0])  # y2
+    coords[:, 4].clamp_(0, img0_shape[1])  # x3
+    coords[:, 5].clamp_(0, img0_shape[0])  # y3
+    coords[:, 6].clamp_(0, img0_shape[1])  # x4
+    coords[:, 7].clamp_(0, img0_shape[0])  # y4
+    coords[:, 8].clamp_(0, img0_shape[1])  # x5
+    coords[:, 9].clamp_(0, img0_shape[0])  # y5
+    return coords
diff --git a/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py b/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2d06587b2d07b2eab199a8484380fde1de5c3c
--- /dev/null
+++ b/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+
+def fuse_conv_and_bn(conv, bn):
+    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
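+    # A sketch of the algebra implemented below: with s = gamma / sqrt(running_var + eps),
+    # the fused layer computes
+    #   W_fused = diag(s) @ W_conv
+    #   b_fused = diag(s) @ b_conv + (beta - s * running_mean)
+    # so Conv2d followed by BatchNorm2d collapses into a single Conv2d with bias.
+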
+    # prepare filters
+    w_conv = conv.weight.clone().view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+
+    # prepare spatial bias
+    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    # Copy attributes from b to a, options to only include [...] and to exclude [...]
+    for k, v in b.__dict__.items():
+        if (include and k not in include) or k.startswith("_") or k in exclude:
+            continue
+
+        setattr(a, k, v)
diff --git a/CodeFormer/facelib/parsing/__init__.py b/CodeFormer/facelib/parsing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72656e4b5f61df8cd0838588b0c6488fcc886e16
--- /dev/null
+++ b/CodeFormer/facelib/parsing/__init__.py
@@ -0,0 +1,23 @@
+import torch
+
+from facelib.utils import load_file_from_url
+from .bisenet import BiSeNet
+from .parsenet import ParseNet
+
+
+def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
+    if model_name == 'bisenet':
+        model = BiSeNet(num_class=19)
+        model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
+    elif model_name == 'parsenet':
+        model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
+        model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+    model.load_state_dict(load_net, strict=True)
+    model.eval()
+    model = model.to(device)
+    return model
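+
+
+# Usage sketch (mirrors how FaceRestoreHelper in facelib/utils uses the parser;
+# `face_512` is a placeholder for a (1, 3, 512, 512) tensor normalized to [-1, 1]):
+#   net = init_parsing_model(model_name='parsenet', device='cpu')
+#   with torch.no_grad():
+#       logits = net(face_512)[0]
+#   parsing = logits.argmax(dim=1)  # (1, 512, 512) label map with 19 classes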
diff --git a/CodeFormer/facelib/parsing/bisenet.py b/CodeFormer/facelib/parsing/bisenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3898cab76ae5876459cd4899c54cafa14234971d
--- /dev/null
+++ b/CodeFormer/facelib/parsing/bisenet.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .resnet import ResNet18
+
+
+class ConvBNReLU(nn.Module):
+
+    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
+        super(ConvBNReLU, self).__init__()
+        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
+        self.bn = nn.BatchNorm2d(out_chan)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = F.relu(self.bn(x))
+        return x
+
+
+class BiSeNetOutput(nn.Module):
+
+    def __init__(self, in_chan, mid_chan, num_class):
+        super(BiSeNetOutput, self).__init__()
+        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+        self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        feat = self.conv(x)
+        out = self.conv_out(feat)
+        return out, feat
+
+
+class AttentionRefinementModule(nn.Module):
+
+    def __init__(self, in_chan, out_chan):
+        super(AttentionRefinementModule, self).__init__()
+        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+        self.bn_atten = nn.BatchNorm2d(out_chan)
+        self.sigmoid_atten = nn.Sigmoid()
+
+    def forward(self, x):
+        feat = self.conv(x)
+        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = self.conv_atten(atten)
+        atten = self.bn_atten(atten)
+        atten = self.sigmoid_atten(atten)
+        out = torch.mul(feat, atten)
+        return out
+
+
+class ContextPath(nn.Module):
+
+    def __init__(self):
+        super(ContextPath, self).__init__()
+        self.resnet = ResNet18()
+        self.arm16 = AttentionRefinementModule(256, 128)
+        self.arm32 = AttentionRefinementModule(512, 128)
+        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+    def forward(self, x):
+        feat8, feat16, feat32 = self.resnet(x)
+        h8, w8 = feat8.size()[2:]
+        h16, w16 = feat16.size()[2:]
+        h32, w32 = feat32.size()[2:]
+
+        avg = F.avg_pool2d(feat32, feat32.size()[2:])
+        avg = self.conv_avg(avg)
+        avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
+
+        feat32_arm = self.arm32(feat32)
+        feat32_sum = feat32_arm + avg_up
+        feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
+        feat32_up = self.conv_head32(feat32_up)
+
+        feat16_arm = self.arm16(feat16)
+        feat16_sum = feat16_arm + feat32_up
+        feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
+        feat16_up = self.conv_head16(feat16_up)
+
+        return feat8, feat16_up, feat32_up  # x8, x8, x16
+
+
+class FeatureFusionModule(nn.Module):
+
+    def __init__(self, in_chan, out_chan):
+        super(FeatureFusionModule, self).__init__()
+        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+        self.relu = nn.ReLU(inplace=True)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, fsp, fcp):
+        fcat = torch.cat([fsp, fcp], dim=1)
+        feat = self.convblk(fcat)
+        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = self.conv1(atten)
+        atten = self.relu(atten)
+        atten = self.conv2(atten)
+        atten = self.sigmoid(atten)
+        feat_atten = torch.mul(feat, atten)
+        feat_out = feat_atten + feat
+        return feat_out
+
+
+class BiSeNet(nn.Module):
+
+    def __init__(self, num_class):
+        super(BiSeNet, self).__init__()
+        self.cp = ContextPath()
+        self.ffm = FeatureFusionModule(256, 256)
+        self.conv_out = BiSeNetOutput(256, 256, num_class)
+        self.conv_out16 = BiSeNetOutput(128, 64, num_class)
+        self.conv_out32 = BiSeNetOutput(128, 64, num_class)
+
+    def forward(self, x, return_feat=False):
+        h, w = x.size()[2:]
+        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # return res3b1 feature
+        feat_sp = feat_res8  # replace spatial path feature with res3b1 feature
+        feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+        out, feat = self.conv_out(feat_fuse)
+        out16, feat16 = self.conv_out16(feat_cp8)
+        out32, feat32 = self.conv_out32(feat_cp16)
+
+        out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
+        out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
+        out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
+
+        if return_feat:
+            feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
+            feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
+            feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
+            return out, out16, out32, feat, feat16, feat32
+        else:
+            return out, out16, out32
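+
+
+# Shape note (derived from forward()): for an input of shape (N, 3, H, W), BiSeNet
+# returns three (N, num_class, H, W) logit maps (the main head plus two auxiliary
+# heads); with return_feat=True the corresponding feature maps, upsampled to (H, W),
+# are returned as well.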
diff --git a/CodeFormer/facelib/parsing/parsenet.py b/CodeFormer/facelib/parsing/parsenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e178ebe43a1ef666aaea0bc0faf629485c22a24f
--- /dev/null
+++ b/CodeFormer/facelib/parsing/parsenet.py
@@ -0,0 +1,194 @@
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+    """Normalization Layers.
+
+    Args:
+        channels: input channels, for batch norm and instance norm.
+        input_size: input shape without batch size, for layer norm.
+    """
+
+    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+        super(NormLayer, self).__init__()
+        norm_type = norm_type.lower()
+        self.norm_type = norm_type
+        if norm_type == 'bn':
+            self.norm = nn.BatchNorm2d(channels, affine=True)
+        elif norm_type == 'in':
+            self.norm = nn.InstanceNorm2d(channels, affine=False)
+        elif norm_type == 'gn':
+            self.norm = nn.GroupNorm(32, channels, affine=True)
+        elif norm_type == 'pixel':
+            self.norm = lambda x: F.normalize(x, p=2, dim=1)
+        elif norm_type == 'layer':
+            self.norm = nn.LayerNorm(normalize_shape)
+        elif norm_type == 'none':
+            self.norm = lambda x: x * 1.0
+        else:
+            raise NotImplementedError(f'Norm type {norm_type} is not supported.')
+
+    def forward(self, x, ref=None):
+        if self.norm_type == 'spade':
+            return self.norm(x, ref)
+        else:
+            return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+    """Relu Layer.
+
+    Args:
+        relu type: type of relu layer, candidates are
+            - ReLU
+            - LeakyReLU: default relu slope 0.2
+            - PRelu
+            - SELU
+            - none: direct pass
+    """
+
+    def __init__(self, channels, relu_type='relu'):
+        super(ReluLayer, self).__init__()
+        relu_type = relu_type.lower()
+        if relu_type == 'relu':
+            self.func = nn.ReLU(True)
+        elif relu_type == 'leakyrelu':
+            self.func = nn.LeakyReLU(0.2, inplace=True)
+        elif relu_type == 'prelu':
+            self.func = nn.PReLU(channels)
+        elif relu_type == 'selu':
+            self.func = nn.SELU(True)
+        elif relu_type == 'none':
+            self.func = lambda x: x * 1.0
+        else:
+            raise NotImplementedError(f'Relu type {relu_type} is not supported.')
+
+    def forward(self, x):
+        return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 scale='none',
+                 norm_type='none',
+                 relu_type='none',
+                 use_pad=True,
+                 bias=True):
+        super(ConvLayer, self).__init__()
+        self.use_pad = use_pad
+        self.norm_type = norm_type
+        if norm_type in ['bn']:
+            bias = False
+
+        stride = 2 if scale == 'down' else 1
+
+        self.scale_func = lambda x: x
+        if scale == 'up':
+            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+        self.relu = ReluLayer(out_channels, relu_type)
+        self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+    def forward(self, x):
+        out = self.scale_func(x)
+        if self.use_pad:
+            out = self.reflection_pad(out)
+        out = self.conv2d(out)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+
+class ResidualBlock(nn.Module):
+    """
+    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+    """
+
+    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+        super(ResidualBlock, self).__init__()
+
+        if scale == 'none' and c_in == c_out:
+            self.shortcut_func = lambda x: x
+        else:
+            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+        scale_conf = scale_config_dict[scale]
+
+        self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+        self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+    def forward(self, x):
+        identity = self.shortcut_func(x)
+
+        res = self.conv1(x)
+        res = self.conv2(res)
+        return identity + res
+
+
+class ParseNet(nn.Module):
+
+    def __init__(self,
+                 in_size=128,
+                 out_size=128,
+                 min_feat_size=32,
+                 base_ch=64,
+                 parsing_ch=19,
+                 res_depth=10,
+                 relu_type='LeakyReLU',
+                 norm_type='bn',
+                 ch_range=[32, 256]):
+        super().__init__()
+        self.res_depth = res_depth
+        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
+        min_ch, max_ch = ch_range
+
+        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
+        min_feat_size = min(in_size, min_feat_size)
+
+        down_steps = int(np.log2(in_size // min_feat_size))
+        up_steps = int(np.log2(out_size // min_feat_size))
+
+        # =============== define encoder-body-decoder ====================
+        self.encoder = []
+        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+        head_ch = base_ch
+        for i in range(down_steps):
+            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
+            head_ch = head_ch * 2
+
+        self.body = []
+        for i in range(res_depth):
+            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+        self.decoder = []
+        for i in range(up_steps):
+            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
+            head_ch = head_ch // 2
+
+        self.encoder = nn.Sequential(*self.encoder)
+        self.body = nn.Sequential(*self.body)
+        self.decoder = nn.Sequential(*self.decoder)
+        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+    def forward(self, x):
+        feat = self.encoder(x)
+        x = feat + self.body(feat)
+        x = self.decoder(x)
+        out_img = self.out_img_conv(x)
+        out_mask = self.out_mask_conv(x)
+        return out_mask, out_img
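+
+
+# Usage sketch: facelib/parsing/__init__.py instantiates this model as
+#   ParseNet(in_size=512, out_size=512, parsing_ch=19)
+# so forward() on a (N, 3, 512, 512) input yields out_mask of shape (N, 19, 512, 512)
+# and out_img of shape (N, 3, 512, 512).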
diff --git a/CodeFormer/facelib/parsing/resnet.py b/CodeFormer/facelib/parsing/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec8e82cf64469fb51be21ad5130217052addbda
--- /dev/null
+++ b/CodeFormer/facelib/parsing/resnet.py
@@ -0,0 +1,69 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+
+    def __init__(self, in_chan, out_chan, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(in_chan, out_chan, stride)
+        self.bn1 = nn.BatchNorm2d(out_chan)
+        self.conv2 = conv3x3(out_chan, out_chan)
+        self.bn2 = nn.BatchNorm2d(out_chan)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        if in_chan != out_chan or stride != 1:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_chan),
+            )
+
+    def forward(self, x):
+        residual = self.conv1(x)
+        residual = F.relu(self.bn1(residual))
+        residual = self.conv2(residual)
+        residual = self.bn2(residual)
+
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x)
+
+        out = shortcut + residual
+        out = self.relu(out)
+        return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+    for i in range(bnum - 1):
+        layers.append(BasicBlock(out_chan, out_chan, stride=1))
+    return nn.Sequential(*layers)
+
+
+class ResNet18(nn.Module):
+
+    def __init__(self):
+        super(ResNet18, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(self.bn1(x))
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        feat8 = self.layer2(x)  # 1/8
+        feat16 = self.layer3(feat8)  # 1/16
+        feat32 = self.layer4(feat16)  # 1/32
+        return feat8, feat16, feat32
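+
+
+# Note: this trimmed ResNet-18 (no avgpool/fc head) returns feature maps at 1/8, 1/16
+# and 1/32 of the input resolution with 128, 256 and 512 channels respectively, which
+# is what ContextPath in bisenet.py expects.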
diff --git a/CodeFormer/facelib/utils/__init__.py b/CodeFormer/facelib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03b1c2bafcd7759cb7e8722a0c6715f201a46dc
--- /dev/null
+++ b/CodeFormer/facelib/utils/__init__.py
@@ -0,0 +1,7 @@
+from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
+
+__all__ = [
+    'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 
+    'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
+]
diff --git a/CodeFormer/facelib/utils/face_restoration_helper.py b/CodeFormer/facelib/utils/face_restoration_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d3fb8f3b95ed9959610e64f6d7373ea8a56ece8
--- /dev/null
+++ b/CodeFormer/facelib/utils/face_restoration_helper.py
@@ -0,0 +1,460 @@
+import cv2
+import numpy as np
+import os
+import torch
+from torchvision.transforms.functional import normalize
+
+from facelib.detection import init_detection_model
+from facelib.parsing import init_parsing_model
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
+
+
+def get_largest_face(det_faces, h, w):
+
+    def get_location(val, length):
+        if val < 0:
+            return 0
+        elif val > length:
+            return length
+        else:
+            return val
+
+    face_areas = []
+    for det_face in det_faces:
+        left = get_location(det_face[0], w)
+        right = get_location(det_face[2], w)
+        top = get_location(det_face[1], h)
+        bottom = get_location(det_face[3], h)
+        face_area = (right - left) * (bottom - top)
+        face_areas.append(face_area)
+    largest_idx = face_areas.index(max(face_areas))
+    return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+    if center is not None:
+        center = np.array(center)
+    else:
+        center = np.array([w / 2, h / 2])
+    center_dist = []
+    for det_face in det_faces:
+        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+        dist = np.linalg.norm(face_center - center)
+        center_dist.append(dist)
+    center_idx = center_dist.index(min(center_dist))
+    return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+    """Helper for the face restoration pipeline (base class)."""
+
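+    # Rough sketch of the intended call order (method names are the ones defined
+    # below; `restore()` stands in for the external face restoration model):
+    #   helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True)
+    #   helper.read_image(img)                        # path or BGR uint8 array
+    #   helper.get_face_landmarks_5(eye_dist_threshold=5)
+    #   helper.align_warp_face()                      # fills helper.cropped_faces
+    #   for face in helper.cropped_faces:
+    #       helper.add_restored_face(restore(face))
+    #   helper.get_inverse_affine()
+    #   result = helper.paste_faces_to_input_image()
+    #   helper.clean_all()
+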
+    def __init__(self,
+                 upscale_factor,
+                 face_size=512,
+                 crop_ratio=(1, 1),
+                 det_model='retinaface_resnet50',
+                 save_ext='png',
+                 template_3points=False,
+                 pad_blur=False,
+                 use_parse=False,
+                 device=None):
+        self.template_3points = template_3points  # improve robustness
+        self.upscale_factor = int(upscale_factor)
+        # the cropped face ratio based on the square face
+        self.crop_ratio = crop_ratio  # (h, w)
+        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >= 1'
+        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+        if self.template_3points:
+            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+        else:
+            # standard 5 landmarks for FFHQ faces at 512 x 512 resolution
+            # facexlib
+            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+                                           [201.26117, 371.41043], [313.08905, 371.15118]])
+
+            # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
+            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+            #                                 [198.22603, 372.82502], [313.91018, 372.75659]])
+
+
+        self.face_template = self.face_template * (face_size / 512.0)
+        if self.crop_ratio[0] > 1:
+            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+        if self.crop_ratio[1] > 1:
+            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+        self.save_ext = save_ext
+        self.pad_blur = pad_blur
+        if self.pad_blur is True:
+            self.template_3points = False
+
+        self.all_landmarks_5 = []
+        self.det_faces = []
+        self.affine_matrices = []
+        self.inverse_affine_matrices = []
+        self.cropped_faces = []
+        self.restored_faces = []
+        self.pad_input_imgs = []
+
+        if device is None:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        else:
+            self.device = device
+
+        # init face detection model
+        self.face_det = init_detection_model(det_model, half=False, device=self.device)
+
+        # init face parsing model
+        self.use_parse = use_parse
+        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
+
+    def set_upscale_factor(self, upscale_factor):
+        self.upscale_factor = upscale_factor
+
+    def read_image(self, img):
+        """img can be image path or cv2 loaded image."""
+        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+        if isinstance(img, str):
+            img = cv2.imread(img)
+
+        if np.max(img) > 256:  # 16-bit image
+            img = img / 65535 * 255
+        if len(img.shape) == 2:  # gray image
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif img.shape[2] == 4:  # BGRA image with alpha channel
+            img = img[:, :, 0:3]
+
+        self.input_img = img
+        self.is_gray = is_gray(img, threshold=5)
+        if self.is_gray:
+            print('Grayscale input: True')
+
+        if min(self.input_img.shape[:2]) < 512:
+            f = 512.0 / min(self.input_img.shape[:2])
+            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+    def get_face_landmarks_5(self,
+                             only_keep_largest=False,
+                             only_center_face=False,
+                             resize=None,
+                             blur_ratio=0.01,
+                             eye_dist_threshold=None):
+        if resize is None:
+            scale = 1
+            input_img = self.input_img
+        else:
+            h, w = self.input_img.shape[0:2]
+            scale = resize / min(h, w)
+            scale = max(1, scale) # always scale up
+            h, w = int(h * scale), int(w * scale)
+            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+        with torch.no_grad():
+            bboxes = self.face_det.detect_faces(input_img)
+
+        if bboxes is None or bboxes.shape[0] == 0:
+            return 0
+        else:
+            bboxes = bboxes / scale
+
+        for bbox in bboxes:
+            # remove faces with too small eye distance: side faces or too small faces
+            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+                continue
+
+            if self.template_3points:
+                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+            else:
+                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+            self.all_landmarks_5.append(landmark)
+            self.det_faces.append(bbox[0:5])
+            
+        if len(self.det_faces) == 0:
+            return 0
+        if only_keep_largest:
+            h, w, _ = self.input_img.shape
+            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+        elif only_center_face:
+            h, w, _ = self.input_img.shape
+            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+        # pad blurry images
+        if self.pad_blur:
+            self.pad_input_imgs = []
+            for landmarks in self.all_landmarks_5:
+                # get landmarks
+                eye_left = landmarks[0, :]
+                eye_right = landmarks[1, :]
+                eye_avg = (eye_left + eye_right) * 0.5
+                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+                eye_to_eye = eye_right - eye_left
+                eye_to_mouth = mouth_avg - eye_avg
+
+                # Get the oriented crop rectangle
+                # x: half width of the oriented crop rectangle
+                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+                # norm with the hypotenuse: get the direction
+                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
+                rect_scale = 1.5
+                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+                # y: half height of the oriented crop rectangle
+                y = np.flipud(x) * [-1, 1]
+
+                # c: center
+                c = eye_avg + eye_to_mouth * 0.1
+                # quad: (left_top, left_bottom, right_bottom, right_top)
+                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+                # qsize: side length of the square
+                qsize = np.hypot(*x) * 2
+                border = max(int(np.rint(qsize * 0.1)), 3)
+
+                # get pad
+                # pad: (width_left, height_top, width_right, height_bottom)
+                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                       int(np.ceil(max(quad[:, 1]))))
+                pad = [
+                    max(-pad[0] + border, 1),
+                    max(-pad[1] + border, 1),
+                    max(pad[2] - self.input_img.shape[0] + border, 1),
+                    max(pad[3] - self.input_img.shape[1] + border, 1)
+                ]
+
+                if max(pad) > 1:
+                    # pad image
+                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+                    # modify landmark coords
+                    landmarks[:, 0] += pad[0]
+                    landmarks[:, 1] += pad[1]
+                    # blur pad images
+                    h, w, _ = pad_img.shape
+                    y, x, _ = np.ogrid[:h, :w, :1]
+                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+                                                       np.float32(w - 1 - x) / pad[2]),
+                                      1.0 - np.minimum(np.float32(y) / pad[1],
+                                                       np.float32(h - 1 - y) / pad[3]))
+                    blur = int(qsize * blur_ratio)
+                    if blur % 2 == 0:
+                        blur += 1
+                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+                    pad_img = pad_img.astype('float32')
+                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
+                    self.pad_input_imgs.append(pad_img)
+                else:
+                    self.pad_input_imgs.append(np.copy(self.input_img))
+
+        return len(self.all_landmarks_5)
+
+    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+        """Align and warp faces with face template.
+        """
+        if self.pad_blur:
+            assert len(self.pad_input_imgs) == len(
+                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+        for idx, landmark in enumerate(self.all_landmarks_5):
+            # use 5 landmarks to get affine matrix
+            # use cv2.LMEDS method for the equivalence to skimage transform
+            # ref: https://blog.csdn.net/yichxi/article/details/115827338
+            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+            self.affine_matrices.append(affine_matrix)
+            # warp and crop faces
+            if border_mode == 'constant':
+                border_mode = cv2.BORDER_CONSTANT
+            elif border_mode == 'reflect101':
+                border_mode = cv2.BORDER_REFLECT101
+            elif border_mode == 'reflect':
+                border_mode = cv2.BORDER_REFLECT
+            if self.pad_blur:
+                input_img = self.pad_input_imgs[idx]
+            else:
+                input_img = self.input_img
+            cropped_face = cv2.warpAffine(
+                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
+            self.cropped_faces.append(cropped_face)
+            # save the cropped face
+            if save_cropped_path is not None:
+                path = os.path.splitext(save_cropped_path)[0]
+                save_path = f'{path}_{idx:02d}.{self.save_ext}'
+                imwrite(cropped_face, save_path)
+
+    def get_inverse_affine(self, save_inverse_affine_path=None):
+        """Get inverse affine matrix."""
+        for idx, affine_matrix in enumerate(self.affine_matrices):
+            inverse_affine = cv2.invertAffineTransform(affine_matrix)
+            inverse_affine *= self.upscale_factor
+            self.inverse_affine_matrices.append(inverse_affine)
+            # save inverse affine matrices
+            if save_inverse_affine_path is not None:
+                path, _ = os.path.splitext(save_inverse_affine_path)
+                save_path = f'{path}_{idx:02d}.pth'
+                torch.save(inverse_affine, save_path)
+
+
+    def add_restored_face(self, face):
+        if self.is_gray:
+            face = bgr2gray(face) # convert img into grayscale
+        self.restored_faces.append(face)
+
+
+    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+        h, w, _ = self.input_img.shape
+        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+        if upsample_img is None:
+            # simply resize the background
+            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+        else:
+            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+        assert len(self.restored_faces) == len(
+            self.inverse_affine_matrices), 'lengths of restored_faces and inverse_affine_matrices differ.'
+        
+        inv_mask_borders = []
+        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+            if face_upsampler is not None:
+                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+                inverse_affine /= self.upscale_factor
+                inverse_affine[:, 2] *= self.upscale_factor
+                face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+            else:
+                # Add an offset to inverse affine matrix, for more precise back alignment
+                if self.upscale_factor > 1:
+                    extra_offset = 0.5 * self.upscale_factor
+                else:
+                    extra_offset = 0
+                inverse_affine[:, 2] += extra_offset
+                face_size = self.face_size
+            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+            # always use square mask
+            mask = np.ones(face_size, dtype=np.float32)
+            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+            # remove the black borders
+            inv_mask_erosion = cv2.erode(
+                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+            total_face_area = np.sum(inv_mask_erosion)  # // 3
+            # add border
+            if draw_box:
+                h, w = face_size
+                mask_border = np.ones((h, w, 3), dtype=np.float32)
+                border = int(1400/np.sqrt(total_face_area))
+                mask_border[border:h-border, border:w-border,:] = 0
+                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+                inv_mask_borders.append(inv_mask_border)
+            # compute the fusion edge based on the area of face
+            w_edge = int(total_face_area**0.5) // 20
+            erosion_radius = w_edge * 2
+            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+            blur_size = w_edge * 2
+            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+            if len(upsample_img.shape) == 2:  # upsample_img is gray image
+                upsample_img = upsample_img[:, :, None]
+            inv_soft_mask = inv_soft_mask[:, :, None]
+
+            # parse mask
+            if self.use_parse:
+                # inference
+                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                face_input = torch.unsqueeze(face_input, 0).to(self.device)
+                with torch.no_grad():
+                    out = self.face_parse(face_input)[0]
+                out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+                parse_mask = np.zeros(out.shape)
+                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+                for idx, color in enumerate(MASK_COLORMAP):
+                    parse_mask[out == idx] = color
+                #  blur the mask
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                # remove the black borders
+                thres = 10
+                parse_mask[:thres, :] = 0
+                parse_mask[-thres:, :] = 0
+                parse_mask[:, :thres] = 0
+                parse_mask[:, -thres:] = 0
+                parse_mask = parse_mask / 255.
+
+                parse_mask = cv2.resize(parse_mask, face_size)
+                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+                inv_soft_parse_mask = parse_mask[:, :, None]
+                # pasted_face = inv_restored
+                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
+                alpha = upsample_img[:, :, 3:]
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+            else:
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+        if np.max(upsample_img) > 256:  # 16-bit image
+            upsample_img = upsample_img.astype(np.uint16)
+        else:
+            upsample_img = upsample_img.astype(np.uint8)
+
+        # draw bounding box
+        if draw_box:
+            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+            img_color[:,:,0] = 0
+            img_color[:,:,1] = 255
+            img_color[:,:,2] = 0
+            for inv_mask_border in inv_mask_borders:
+                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+        if save_path is not None:
+            path = os.path.splitext(save_path)[0]
+            save_path = f'{path}.{self.save_ext}'
+            imwrite(upsample_img, save_path)
+        return upsample_img
+
+    def clean_all(self):
+        self.all_landmarks_5 = []
+        self.restored_faces = []
+        self.affine_matrices = []
+        self.cropped_faces = []
+        self.inverse_affine_matrices = []
+        self.det_faces = []
+        self.pad_input_imgs = []
\ No newline at end of file
diff --git a/CodeFormer/facelib/utils/face_utils.py b/CodeFormer/facelib/utils/face_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1474a2a4419b6b62fab8a919ef805b802556464
--- /dev/null
+++ b/CodeFormer/facelib/utils/face_utils.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+    left, top, right, bot = bbox
+    width = right - left
+    height = bot - top
+
+    if preserve_aspect:
+        width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+        height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+    else:
+        width_increase = height_increase = increase_area
+    left = int(left - width_increase * width)
+    top = int(top - height_increase * height)
+    right = int(right + width_increase * width)
+    bot = int(bot + height_increase * height)
+    return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+    left = max(bboxes[0], 0)
+    top = max(bboxes[1], 0)
+    right = min(bboxes[2], w)
+    bottom = min(bboxes[3], h)
+    return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+                              landmarks,
+                              output_size,
+                              transform_size=None,
+                              enable_padding=True,
+                              return_inverse_affine=False,
+                              shrink_ratio=(1, 1)):
+    """Align and crop face with landmarks.
+
+    The output_size and transform_size are based on width. The height is
+    adjusted by shrink_ratio_h / shrink_ratio_w.
+
+    Modified from:
+    https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+    Args:
+        img (Numpy array): Input image.
+        landmarks (Numpy array): 5, 68 or 98 landmarks.
+        output_size (int): Output face size.
+        transform_size (int): Transform size. Usually four times
+            output_size.
+        enable_padding (bool): Whether to pad the image. Default: True.
+        return_inverse_affine (bool): Whether to also return the inverse affine
+            matrix. Default: False.
+        shrink_ratio (float | tuple[float] | list[float]): Shrink ratio for the
+            height and width of the whole face (crop a larger area). Default: (1, 1).
+
+    Returns:
+        (Numpy array, Numpy array | None): Cropped face and the inverse affine
+            matrix (None if return_inverse_affine is False).
+    """
+    lm_type = 'retinaface_5'  # Options: dlib_5, retinaface_5
+
+    if isinstance(shrink_ratio, (float, int)):
+        shrink_ratio = (shrink_ratio, shrink_ratio)
+    if transform_size is None:
+        transform_size = output_size * 4
+
+    # Parse landmarks
+    lm = np.array(landmarks)
+    if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+        eye_left = lm[0]
+        eye_right = lm[1]
+        mouth_avg = (lm[3] + lm[4]) * 0.5
+    elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+        lm_eye_left = lm[2:4]
+        lm_eye_right = lm[0:2]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = lm[4]
+    elif lm.shape[0] == 68:
+        lm_eye_left = lm[36:42]
+        lm_eye_right = lm[42:48]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = (lm[48] + lm[54]) * 0.5
+    elif lm.shape[0] == 98:
+        lm_eye_left = lm[60:68]
+        lm_eye_right = lm[68:76]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = (lm[76] + lm[82]) * 0.5
+
+    eye_avg = (eye_left + eye_right) * 0.5
+    eye_to_eye = eye_right - eye_left
+    eye_to_mouth = mouth_avg - eye_avg
+
+    # Get the oriented crop rectangle
+    # x: half width of the oriented crop rectangle
+    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+    #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+    # norm with the hypotenuse: get the direction
+    x /= np.hypot(*x)  # get the hypotenuse of a right triangle
+    rect_scale = 1  # TODO: you can edit it to get larger rect
+    x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+    # y: half height of the oriented crop rectangle
+    y = np.flipud(x) * [-1, 1]
+
+    x *= shrink_ratio[1]  # width
+    y *= shrink_ratio[0]  # height
+
+    # c: center
+    c = eye_avg + eye_to_mouth * 0.1
+    # quad: (left_top, left_bottom, right_bottom, right_top)
+    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+    # qsize: side length of the square
+    qsize = np.hypot(*x) * 2
+
+    quad_ori = np.copy(quad)
+    # Shrink, for large face
+    # TODO: do we really need shrink
+    shrink = int(np.floor(qsize / output_size * 0.5))
+    if shrink > 1:
+        h, w = img.shape[0:2]
+        rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+        img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+        quad /= shrink
+        qsize /= shrink
+
+    # Crop
+    h, w = img.shape[0:2]
+    border = max(int(np.rint(qsize * 0.1)), 3)
+    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+            int(np.ceil(max(quad[:, 1]))))
+    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+    if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+        img = img[crop[1]:crop[3], crop[0]:crop[2], :]
+        quad -= crop[0:2]
+
+    # Pad
+    # pad: (width_left, height_top, width_right, height_bottom)
+    h, w = img.shape[0:2]
+    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+           int(np.ceil(max(quad[:, 1]))))
+    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
+    if enable_padding and max(pad) > border - 4:
+        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+        img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+        h, w = img.shape[0:2]
+        y, x, _ = np.ogrid[:h, :w, :1]
+        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+                                           np.float32(w - 1 - x) / pad[2]),
+                          1.0 - np.minimum(np.float32(y) / pad[1],
+                                           np.float32(h - 1 - y) / pad[3]))
+        blur = int(qsize * 0.02)
+        if blur % 2 == 0:
+            blur += 1
+        blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+        img = img.astype('float32')
+        img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+        img = np.clip(img, 0, 255)  # float32, [0, 255]
+        quad += pad[:2]
+
+    # Transform use cv2
+    h_ratio = shrink_ratio[0] / shrink_ratio[1]
+    dst_h, dst_w = int(transform_size * h_ratio), transform_size
+    template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+    # use cv2.LMEDS method for the equivalence to skimage transform
+    # ref: https://blog.csdn.net/yichxi/article/details/115827338
+    affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+    cropped_face = cv2.warpAffine(
+        img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132))  # gray
+
+    if output_size < transform_size:
+        cropped_face = cv2.resize(
+            cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
+
+    if return_inverse_affine:
+        dst_h, dst_w = int(output_size * h_ratio), output_size
+        template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+        # use cv2.LMEDS method for the equivalence to skimage transform
+        # ref: https://blog.csdn.net/yichxi/article/details/115827338
+        affine_matrix = cv2.estimateAffinePartial2D(quad_ori, template, method=cv2.LMEDS)[0]
+        inverse_affine = cv2.invertAffineTransform(affine_matrix)
+    else:
+        inverse_affine = None
+    return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+    h, w = img.shape[0:2]
+    face_h, face_w = face.shape[0:2]
+    inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+    mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+    inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+    # remove the black borders
+    inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+    inv_restored_remove_border = inv_mask_erosion * inv_restored
+    total_face_area = np.sum(inv_mask_erosion) // 3
+    # compute the fusion edge based on the area of face
+    w_edge = int(total_face_area**0.5) // 20
+    erosion_radius = w_edge * 2
+    inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+    blur_size = w_edge * 2
+    inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+    img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+    # float32, [0, 255]
+    return img
+
+
+if __name__ == '__main__':
+    import os
+
+    from facelib.detection import init_detection_model
+    from facelib.utils.face_restoration_helper import get_largest_face
+
+    img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+    img_name = os.path.splitext(os.path.basename(img_path))[0]
+
+    # initialize model
+    det_net = init_detection_model('retinaface_resnet50', half=False)
+    img_ori = cv2.imread(img_path)
+    h, w = img_ori.shape[0:2]
+    # if larger than 800, scale it
+    scale = max(h / 800, w / 800)
+    if scale > 1:
+        img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+    else:
+        img = img_ori
+
+    with torch.no_grad():
+        bboxes = det_net.detect_faces(img, 0.97)
+    if scale > 1:
+        bboxes *= scale  # the score is incorrect
+    bboxes = get_largest_face(bboxes, h, w)[0]
+
+    landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+    cropped_face, inverse_affine = align_crop_face_landmarks(
+        img_ori,
+        landmarks,
+        output_size=512,
+        transform_size=None,
+        enable_padding=True,
+        return_inverse_affine=True,
+        shrink_ratio=(1, 1))
+
+    cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
+    img = paste_face_back(img_ori, cropped_face, inverse_affine)
+    cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/CodeFormer/facelib/utils/misc.py b/CodeFormer/facelib/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e2c0343f972d5bd5c735c5cfbf8b28bca6dd55
--- /dev/null
+++ b/CodeFormer/facelib/utils/misc.py
@@ -0,0 +1,174 @@
+import cv2
+import os
+import os.path as osp
+import numpy as np
+from PIL import Image
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from basicsr.utils.download_util import download_file_from_google_drive
+# import gdown
+
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+    os.makedirs(save_path_root, exist_ok=True)
+
+    for file_name, file_id in file_ids.items():
+        file_url = 'https://drive.google.com/uc?id='+file_id
+        save_path = osp.abspath(osp.join(save_path_root, file_name))
+        if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                # gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Wrong input. Only accepts Y/N.')
+        else:
+            print(f'Downloading {file_name} to {save_path}')
+            # gdown.download(file_url, save_path, quiet=False)
+            # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Numpy array to tensor.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input images.
+        bgr2rgb (bool): Whether to change bgr to rgb.
+        float32 (bool): Whether to change to float32.
+
+    Returns:
+        list[tensor] | tensor: Tensor images. If returned results only have
+            one element, just return tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    """
+    if model_dir is None:
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+    Args:
+        dir_path (str): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+    Returns:
+        A generator over the matching files, yielding relative paths (or full paths if `full_path` is True).
+    """
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if suffix is None:
+                    yield return_path
+                elif return_path.endswith(suffix):
+                    yield return_path
+            elif recursive and entry.is_dir():
+                # only recurse into sub-directories; hidden files would otherwise be
+                # passed to os.scandir and raise NotADirectoryError
+                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def is_gray(img, threshold=10):
+    img = Image.fromarray(img)
+    if len(img.getbands()) == 1:
+        return True
+    img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
+    img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
+    img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
+    diff1 = (img1 - img2).var()
+    diff2 = (img2 - img3).var()
+    diff3 = (img3 - img1).var()
+    diff_sum = (diff1 + diff2 + diff3) / 3.0
+    if diff_sum <= threshold:
+        return True
+    else:
+        return False
+
+def rgb2gray(img, out_channel=3):
+    r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
+    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+    if out_channel == 3:
+        gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+    return gray
+
+def bgr2gray(img, out_channel=3):
+    b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
+    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+    if out_channel == 3:
+        gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+    return gray
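The helpers in `misc.py` are used throughout the face library: `scandir` walks a folder lazily, `img2tensor` converts HWC arrays to CHW tensors, `imwrite` creates parent folders on the fly, and `is_gray`/`bgr2gray` detect and normalize grayscale inputs by checking the variance of channel differences. A small usage sketch, assuming this repository is importable and that `./inputs/cropped_faces` holds a few test images (both are assumptions for illustration):

```python
import os
import cv2
from facelib.utils.misc import scandir, img2tensor, imwrite, is_gray, bgr2gray

for path in scandir('./inputs/cropped_faces', suffix=('.png', '.jpg'), full_path=True):
    img = cv2.imread(path, cv2.IMREAD_COLOR)                      # BGR uint8, HxWx3
    if is_gray(img, threshold=10):
        # force an exactly-gray 3-channel image so downstream code sees consistent input
        img = bgr2gray(img, out_channel=3).astype('uint8')
    tensor = img2tensor(img / 255.0, bgr2rgb=True, float32=True)  # CxHxW float32 tensor
    print(path, tuple(tensor.shape), tensor.dtype)
    imwrite(img, os.path.join('tmp/checked', os.path.basename(path)))  # auto-creates tmp/checked/
```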
diff --git a/CodeFormer/inference_codeformer.py b/CodeFormer/inference_codeformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdfe8b301cc7c20c2fb653618e379d243603a108
--- /dev/null
+++ b/CodeFormer/inference_codeformer.py
@@ -0,0 +1,189 @@
+# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
+import os
+import cv2
+import argparse
+import glob
+import torch
+from torchvision.transforms.functional import normalize
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+pretrain_model_url = {
+    'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+}
+
+def set_realesrgan():
+    if not torch.cuda.is_available():  # CPU
+        import warnings
+        warnings.warn('The unoptimized RealESRGAN is slow on CPU, so it is disabled here. '
+                      'If you really want to use it, please modify the corresponding code.',
+                      category=RuntimeWarning)
+        bg_upsampler = None
+    else:
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from basicsr.utils.realesrgan_utils import RealESRGANer
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        bg_upsampler = RealESRGANer(
+            scale=2,
+            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+            model=model,
+            tile=args.bg_tile,
+            tile_pad=40,
+            pre_pad=0,
+            half=True)  # need to set False in CPU mode
+    return bg_upsampler
+
+if __name__ == '__main__':
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--w', type=float, default=0.5, help='Balance between quality (w=0) and fidelity (w=1). Default: 0.5')
+    parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
+    parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
+    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
+    # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+    # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
+    parser.add_argument('--draw_box', action='store_true')
+    parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Options: realesrgan')
+    parser.add_argument('--face_upsample', action='store_true', help='Apply the face upsampler after enhancement.')
+    parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for the background upsampler. Default: 400')
+
+    args = parser.parse_args()
+
+    # ------------------------ input & output ------------------------
+    if args.test_path.endswith('/'):  # strip the trailing slash, if any
+        args.test_path = args.test_path[:-1]
+
+    w = args.w
+    result_root = f'results/{os.path.basename(args.test_path)}_{w}'
+
+    # ------------------ set up background upsampler ------------------
+    if args.bg_upsampler == 'realesrgan':
+        bg_upsampler = set_realesrgan()
+    else:
+        bg_upsampler = None
+
+    # ------------------ set up face upsampler ------------------
+    if args.face_upsample:
+        if bg_upsampler is not None:
+            face_upsampler = bg_upsampler
+        else:
+            face_upsampler = set_realesrgan()
+    else:
+        face_upsampler = None
+
+    # ------------------ set up CodeFormer restorer -------------------
+    net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 
+                                            connect_list=['32', '64', '128', '256']).to(device)
+    
+    # ckpt_path = 'weights/CodeFormer/codeformer.pth'
+    ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], 
+                                    model_dir='weights/CodeFormer', progress=True, file_name=None)
+    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']  # map_location keeps CPU-only machines working
+    net.load_state_dict(checkpoint)
+    net.eval()
+
+    # ------------------ set up FaceRestoreHelper -------------------
+    # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+    # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+    if not args.has_aligned: 
+        print(f'Face detection model: {args.detection_model}')
+    if bg_upsampler is not None: 
+        print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
+    else:
+        print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
+
+    face_helper = FaceRestoreHelper(
+        args.upscale,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model = args.detection_model,
+        save_ext='png',
+        use_parse=True,
+        device=device)
+
+    # -------------------- start processing ---------------------
+    # scan all the jpg and png images
+    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
+        # clean all the intermediate results to process the next image
+        face_helper.clean_all()
+        
+        img_name = os.path.basename(img_path)
+        print(f'Processing: {img_name}')
+        basename, ext = os.path.splitext(img_name)
+        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+
+        if args.has_aligned: 
+            # the input faces are already cropped and aligned
+            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+            face_helper.cropped_faces = [img]
+        else:
+            face_helper.read_image(img)
+            # get face landmarks for each face
+            num_det_faces = face_helper.get_face_landmarks_5(
+                only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
+            print(f'\tdetected {num_det_faces} faces')
+            # align and warp each face
+            face_helper.align_warp_face()
+
+        # face restoration for each cropped face
+        for idx, cropped_face in enumerate(face_helper.cropped_faces):
+            # prepare data
+            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+            try:
+                with torch.no_grad():
+                    output = net(cropped_face_t, w=w, adain=True)[0]
+                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+                del output
+                torch.cuda.empty_cache()
+            except Exception as error:
+                print(f'\tFailed inference for CodeFormer: {error}')
+                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+            restored_face = restored_face.astype('uint8')
+            face_helper.add_restored_face(restored_face)
+
+        # paste_back
+        if not args.has_aligned:
+            # upsample the background
+            if bg_upsampler is not None:
+                # currently only RealESRGAN is supported for background upsampling
+                bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
+            else:
+                bg_img = None
+            face_helper.get_inverse_affine(None)
+            # paste each restored face to the input image
+            if args.face_upsample and face_upsampler is not None: 
+                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
+            else:
+                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
+
+        # save faces
+        for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
+            # save cropped face
+            if not args.has_aligned: 
+                save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
+                imwrite(cropped_face, save_crop_path)
+            # save restored face
+            if args.has_aligned:
+                save_face_name = f'{basename}.png'
+            else:
+                save_face_name = f'{basename}_{idx:02d}.png'
+            save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
+            imwrite(restored_face, save_restore_path)
+
+        # save restored img
+        if not args.has_aligned and restored_img is not None:
+            save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
+            imwrite(restored_img, save_restore_path)
+
+    print(f'\nAll results are saved in {result_root}')
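For a single already-aligned face, the per-face loop above boils down to a tensor round trip through the network. Below is a condensed, hedged sketch: the input path `./inputs/cropped_faces/0001.png` is an assumption for illustration, while the weight URL, architecture arguments, and helper calls mirror the script itself.

```python
import cv2
import torch
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor, tensor2img, imwrite
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.registry import ARCH_REGISTRY

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8,
                                      n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
ckpt = load_file_from_url(
    url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    model_dir='weights/CodeFormer', progress=True, file_name=None)
net.load_state_dict(torch.load(ckpt, map_location=device)['params_ema'])
net.eval()

face = cv2.imread('./inputs/cropped_faces/0001.png', cv2.IMREAD_COLOR)  # assumed 512x512 aligned face
face_t = img2tensor(face / 255., bgr2rgb=True, float32=True)
normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
with torch.no_grad():
    # w trades fidelity (w -> 1) against quality (w -> 0)
    out = net(face_t.unsqueeze(0).to(device), w=0.7, adain=True)[0]
restored = tensor2img(out, rgb2bgr=True, min_max=(-1, 1)).astype('uint8')
imwrite(restored, 'results/0001_restored.png')
```

A typical full run over unaligned photos would look something like `python inference_codeformer.py --w 0.7 --upscale 2 --test_path ./inputs/whole_imgs --bg_upsampler realesrgan --face_upsample`, which additionally detects, aligns, and pastes faces back via `FaceRestoreHelper`.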
diff --git a/CodeFormer/requirements.txt b/CodeFormer/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f97dfde85ebe83708fc1f6f7234a0ef69f18bde5
--- /dev/null
+++ b/CodeFormer/requirements.txt
@@ -0,0 +1,20 @@
+addict
+future
+lmdb
+numpy
+opencv-python
+Pillow
+pyyaml
+requests
+scikit-image
+scipy
+tb-nightly
+torch>=1.7.1
+torchvision
+tqdm
+yapf
+lpips
+gdown # supports downloading large files from Google Drive
+# cmake
+# dlib
+# conda install -c conda-forge dlib
\ No newline at end of file
diff --git a/CodeFormer/scripts/crop_align_face.py b/CodeFormer/scripts/crop_align_face.py
new file mode 100755
index 0000000000000000000000000000000000000000..31e66266ac0e5f818fa18b6409993151086bbc8b
--- /dev/null
+++ b/CodeFormer/scripts/crop_align_face.py
@@ -0,0 +1,192 @@
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
+date: 2020.1.5
+note: code is heavily borrowed from
+    https://github.com/NVlabs/ffhq-dataset
+    http://dlib.net/face_landmark_detection.py.html
+requirements:
+    conda install Pillow numpy scipy
+    conda install -c conda-forge dlib
+    # download face landmark model from:
+    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+
+import cv2
+import dlib
+import glob
+import numpy as np
+import os
+import PIL
+import PIL.Image
+import scipy
+import scipy.ndimage
+import sys
+import argparse
+
+# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
+
+
+def get_landmark(filepath, only_keep_largest=True):
+    """get landmark with dlib
+    :return: np.array shape=(68, 2)
+    """
+    detector = dlib.get_frontal_face_detector()
+
+    img = dlib.load_rgb_image(filepath)
+    dets = detector(img, 1)
+
+    # Shangchen modified
+    print("Number of faces detected: {}".format(len(dets)))
+    if only_keep_largest:
+        print('Keeping only the largest detected face.')
+        face_areas = []
+        for k, d in enumerate(dets):
+            face_area = (d.right() - d.left()) * (d.bottom() - d.top())
+            face_areas.append(face_area)
+
+        largest_idx = face_areas.index(max(face_areas))
+        d = dets[largest_idx]
+        shape = predictor(img, d)
+        print("Part 0: {}, Part 1: {} ...".format(
+            shape.part(0), shape.part(1)))
+    else:
+        for k, d in enumerate(dets):
+            print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
+                k, d.left(), d.top(), d.right(), d.bottom()))
+            # Get the landmarks/parts for the face in box d.
+            shape = predictor(img, d)
+            print("Part 0: {}, Part 1: {} ...".format(
+                shape.part(0), shape.part(1)))
+
+    t = list(shape.parts())
+    a = []
+    for tt in t:
+        a.append([tt.x, tt.y])
+    lm = np.array(a)
+    # lm is a shape=(68,2) np.array
+    return lm
+
+def align_face(filepath, out_path):
+    """
+    :param filepath: str
+    :return: PIL Image
+    """
+    try:
+        lm = get_landmark(filepath)
+    except Exception:
+        print('No face/landmarks detected, skipping ...')
+        return
+
+    lm_chin = lm[0:17]  # left-right
+    lm_eyebrow_left = lm[17:22]  # left-right
+    lm_eyebrow_right = lm[22:27]  # left-right
+    lm_nose = lm[27:31]  # top-down
+    lm_nostrils = lm[31:36]  # top-down
+    lm_eye_left = lm[36:42]  # left-clockwise
+    lm_eye_right = lm[42:48]  # left-clockwise
+    lm_mouth_outer = lm[48:60]  # left-clockwise
+    lm_mouth_inner = lm[60:68]  # left-clockwise
+
+    # Calculate auxiliary vectors.
+    eye_left = np.mean(lm_eye_left, axis=0)
+    eye_right = np.mean(lm_eye_right, axis=0)
+    eye_avg = (eye_left + eye_right) * 0.5
+    eye_to_eye = eye_right - eye_left
+    mouth_left = lm_mouth_outer[0]
+    mouth_right = lm_mouth_outer[6]
+    mouth_avg = (mouth_left + mouth_right) * 0.5
+    eye_to_mouth = mouth_avg - eye_avg
+
+    # Choose oriented crop rectangle.
+    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+    x /= np.hypot(*x)
+    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+    y = np.flipud(x) * [-1, 1]
+    c = eye_avg + eye_to_mouth * 0.1
+    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+    qsize = np.hypot(*x) * 2
+
+    # read image
+    img = PIL.Image.open(filepath)
+
+    output_size = 512
+    transform_size = 4096
+    enable_padding = False
+
+    # Shrink.
+    shrink = int(np.floor(qsize / output_size * 0.5))
+    if shrink > 1:
+        rsize = (int(np.rint(float(img.size[0]) / shrink)),
+                 int(np.rint(float(img.size[1]) / shrink)))
+        img = img.resize(rsize, PIL.Image.ANTIALIAS)
+        quad /= shrink
+        qsize /= shrink
+ 
+    # Crop.
+    border = max(int(np.rint(qsize * 0.1)), 3)
+    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
+            min(crop[2] + border,
+                img.size[0]), min(crop[3] + border, img.size[1]))
+    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+        img = img.crop(crop)
+        quad -= crop[0:2]
+
+    # Pad.
+    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+           int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+    pad = (max(-pad[0] + border,
+               0), max(-pad[1] + border,
+                       0), max(pad[2] - img.size[0] + border,
+                               0), max(pad[3] - img.size[1] + border, 0))
+    if enable_padding and max(pad) > border - 4:
+        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+        img = np.pad(
+            np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
+            'reflect')
+        h, w, _ = img.shape
+        y, x, _ = np.ogrid[:h, :w, :1]
+        mask = np.maximum(
+            1.0 -
+            np.minimum(np.float32(x) / pad[0],
+                       np.float32(w - 1 - x) / pad[2]), 1.0 -
+            np.minimum(np.float32(y) / pad[1],
+                       np.float32(h - 1 - y) / pad[3]))
+        blur = qsize * 0.02
+        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
+                img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+        img = PIL.Image.fromarray(
+            np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+        quad += pad[:2]
+
+    img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
+                        (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+
+    if output_size < transform_size:
+        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+    # Save aligned image.
+    print('saving:', out_path)
+    img.save(out_path)
+
+    return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs')
+    parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces')
+    args = parser.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+    img_list = sorted(glob.glob(f'{args.in_dir}/*.png') + glob.glob(f'{args.in_dir}/*.jpg'))
+
+    for in_path in img_list:
+        out_path = os.path.join(args.out_dir, os.path.basename(in_path))
+        out_path = out_path.replace('.jpg', '.png')
+        size_ = align_face(in_path, out_path)
\ No newline at end of file
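The geometric heart of `align_face` is the oriented crop square built from the two eye centers and the two mouth corners; everything after that (shrink, crop, pad, `img.transform`) just resamples the image onto that square. A numpy-only sketch of the quad computation, using made-up landmark coordinates for illustration:

```python
import numpy as np

eye_left = np.array([190.0, 240.0])
eye_right = np.array([320.0, 238.0])
mouth_left = np.array([215.0, 360.0])
mouth_right = np.array([300.0, 362.0])

eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# x is the horizontal axis of the crop: roughly along the eye line, tilted away from
# the eye-to-mouth direction, and scaled so the square covers the whole face
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1          # crop center, nudged toward the mouth

quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])  # 4 corners of the square
qsize = np.hypot(*x) * 2                  # side length of the square in pixels
print(np.round(quad, 1), round(float(qsize), 1))
```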
diff --git a/CodeFormer/scripts/download_pretrained_models.py b/CodeFormer/scripts/download_pretrained_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa6e8ca14ea91c89a318e85d9f182eb7d1bf025
--- /dev/null
+++ b/CodeFormer/scripts/download_pretrained_models.py
@@ -0,0 +1,40 @@
+import argparse
+import os
+from os import path as osp
+
+from basicsr.utils.download_util import load_file_from_url
+
+
+def download_pretrained_models(method, file_urls):
+    save_path_root = f'./weights/{method}'
+    os.makedirs(save_path_root, exist_ok=True)
+
+    for file_name, file_url in file_urls.items():
+        save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        'method',
+        type=str,
+        help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+    args = parser.parse_args()
+
+    file_urls = {
+        'CodeFormer': {
+            'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+        },
+        'facelib': {
+            # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
+            'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+            'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+        }
+    }
+
+    if args.method == 'all':
+        for method in file_urls.keys():
+            download_pretrained_models(method, file_urls[method])
+    else:
+        download_pretrained_models(args.method, file_urls[args.method])
\ No newline at end of file
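The script above is normally invoked as, e.g., `python scripts/download_pretrained_models.py facelib`. The same effect can be had programmatically through `load_file_from_url`; the sketch below mirrors the script's `./weights/<method>` layout and reuses the release URLs listed in it.

```python
import os
from basicsr.utils.download_util import load_file_from_url

facelib_urls = {
    'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
}
save_path_root = './weights/facelib'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_url in facelib_urls.items():
    path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
    print('cached at', path)
```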
diff --git a/CodeFormer/scripts/download_pretrained_models_from_gdrive.py b/CodeFormer/scripts/download_pretrained_models_from_gdrive.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df5be6fc260394ee9bbd0a7ae377e2ca657fe83
--- /dev/null
+++ b/CodeFormer/scripts/download_pretrained_models_from_gdrive.py
@@ -0,0 +1,60 @@
+import argparse
+import os
+from os import path as osp
+
+# from basicsr.utils.download_util import download_file_from_google_drive
+import gdown
+
+
+def download_pretrained_models(method, file_ids):
+    save_path_root = f'./weights/{method}'
+    os.makedirs(save_path_root, exist_ok=True)
+
+    for file_name, file_id in file_ids.items():
+        file_url = 'https://drive.google.com/uc?id='+file_id
+        save_path = osp.abspath(osp.join(save_path_root, file_name))
+        if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Invalid input. Only Y/N is accepted.')
+        else:
+            print(f'Downloading {file_name} to {save_path}')
+            gdown.download(file_url, save_path, quiet=False)
+            # download_file_from_google_drive(file_id, save_path)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        'method',
+        type=str,
+        help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+    args = parser.parse_args()
+
+    # file name: file id
+    # 'dlib': {
+    #     'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
+    #     'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
+    #     'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
+    # }
+    file_ids = {
+        'CodeFormer': {
+            'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
+        },
+        'facelib': {
+            'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
+            'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
+        }
+    }
+
+    if args.method == 'all':
+        for method in file_ids.keys():
+            download_pretrained_models(method, file_ids[method])
+    else:
+        download_pretrained_models(args.method, file_ids[args.method])
\ No newline at end of file
diff --git a/CodeFormer/weights/README.md b/CodeFormer/weights/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..67ad334bd672eeb9f82813cd54e8885331bbb2f2
--- /dev/null
+++ b/CodeFormer/weights/README.md
@@ -0,0 +1,3 @@
+# Weights
+
+Put the downloaded pre-trained models in this folder.
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e03522a6d228292d24f8b99a900e3107ec3abe
--- /dev/null
+++ b/app.py
@@ -0,0 +1,260 @@
+"""
+This file is used for deploying hugging face demo:
+https://huggingface.co/spaces/sczhou/CodeFormer
+"""
+
+import sys
+sys.path.append('CodeFormer')
+import os
+import cv2
+import torch
+import torch.nn.functional as F
+import gradio as gr
+
+from torchvision.transforms.functional import normalize
+
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.realesrgan_utils import RealESRGANer
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+os.system("pip freeze")
+
+pretrain_model_url = {
+    'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+    'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+    'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
+    'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
+}
+# download weights
+if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
+    load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
+    load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
+    load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
+    load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
+
+# download images
+torch.hub.download_url_to_file(
+    'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
+    '01.png')
+torch.hub.download_url_to_file(
+    'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
+    '02.jpg')
+torch.hub.download_url_to_file(
+    'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
+    '03.jpg')
+torch.hub.download_url_to_file(
+    'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
+    '04.jpg')
+torch.hub.download_url_to_file(
+    'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
+    '05.jpg')
+
+def imread(img_path):
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+# set enhancer with RealESRGAN
+def set_realesrgan():
+    half = torch.cuda.is_available()  # use fp16 only on GPU
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=400,
+        tile_pad=40,
+        pre_pad=0,
+        half=half,
+    )
+    return upsampler
+
+upsampler = set_realesrgan()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+    dim_embd=512,
+    codebook_size=1024,
+    n_head=8,
+    n_layers=9,
+    connect_list=["32", "64", "128", "256"],
+).to(device)
+ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
+checkpoint = torch.load(ckpt_path, map_location=device)["params_ema"]  # map_location keeps CPU-only deployments working
+codeformer_net.load_state_dict(checkpoint)
+codeformer_net.eval()
+
+os.makedirs('output', exist_ok=True)
+
+def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
+    """Run a single prediction on the model"""
+    # take the default setting for the demo
+    has_aligned = False
+    only_center_face = False
+    draw_box = False
+    detection_model = "retinaface_resnet50"
+
+    upscale = int(upscale)  # convert type to int
+    face_helper = FaceRestoreHelper(
+        upscale,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model=detection_model,
+        save_ext="png",
+        use_parse=True,
+        device=device,
+    )
+    bg_upsampler = upsampler if background_enhance else None
+    face_upsampler = upsampler if face_upsample else None
+
+    img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+
+    if has_aligned:
+        # the input faces are already cropped and aligned
+        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+        face_helper.is_gray = is_gray(img, threshold=5)
+        if face_helper.is_gray:
+            print('Grayscale input: True')
+        face_helper.cropped_faces = [img]
+    else:
+        face_helper.read_image(img)
+        # get face landmarks for each face
+        num_det_faces = face_helper.get_face_landmarks_5(
+          only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+        )
+        print(f"\tdetect {num_det_faces} faces")
+        # align and warp each face
+        face_helper.align_warp_face()
+
+    # face restoration for each cropped face
+    for idx, cropped_face in enumerate(face_helper.cropped_faces):
+        # prepare data
+        cropped_face_t = img2tensor(
+            cropped_face / 255.0, bgr2rgb=True, float32=True
+        )
+        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+        try:
+            with torch.no_grad():
+                output = codeformer_net(
+                    cropped_face_t, w=codeformer_fidelity, adain=True
+                )[0]
+                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+            del output
+            torch.cuda.empty_cache()
+        except Exception as error:
+            print(f"\tFailed inference for CodeFormer: {error}")
+            restored_face = tensor2img(
+                cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+            )
+
+        restored_face = restored_face.astype("uint8")
+        face_helper.add_restored_face(restored_face)
+
+    # paste_back
+    if not has_aligned:
+        # upsample the background
+        if bg_upsampler is not None:
+            # currently only RealESRGAN is supported for background upsampling
+            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+        else:
+            bg_img = None
+        face_helper.get_inverse_affine(None)
+        # paste each restored face to the input image
+        if face_upsample and face_upsampler is not None:
+            restored_img = face_helper.paste_faces_to_input_image(
+                upsample_img=bg_img,
+                draw_box=draw_box,
+                face_upsampler=face_upsampler,
+            )
+        else:
+            restored_img = face_helper.paste_faces_to_input_image(
+                upsample_img=bg_img, draw_box=draw_box
+            )
+
+    # save restored img
+    save_path = 'output/out.png'
+    imwrite(restored_img, str(save_path))
+
+    restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
+    return restored_img, save_path
+
+
+
+title = "CodeFormer: Robust Face Restoration and Enhancement Network"
+description = r"""<center><img src='https://user-images.githubusercontent.com/14334509/189166076-94bb2cac-4f4e-40fb-a69f-66709e3d98f5.png' alt='CodeFormer logo'></center>
+<b>Official Gradio demo</b> for <a href='https://github.com/sczhou/CodeFormer' target='_blank'><b>Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)</b></a>.<br>
+🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.<br>
+🤗 Try CodeFormer for improved stable-diffusion generation!<br>
+"""
+article = r"""
+If CodeFormer is helpful, please help to ⭐ the <a href='https://github.com/sczhou/CodeFormer' target='_blank'>Github Repo</a>. Thanks! 
+[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
+
+---
+
+📝 **Citation**
+
+If our work is useful for your research, please consider citing:
+```bibtex
+@inproceedings{zhou2022codeformer,
+    author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+    title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+    booktitle = {NeurIPS},
+    year = {2022}
+}
+```
+
+📋 **License**
+
+This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>. 
+Redistribution and use for non-commercial purposes should follow this license.
+
+📧 **Contact**
+
+If you have any questions, please feel free to reach out to me at <b>shangchenzhou@gmail.com</b>.
+
+![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
+"""
+
+demo = gr.Interface(
+    inference, [
+        gr.inputs.Image(type="filepath", label="Input"),
+        gr.inputs.Checkbox(default=True, label="Background_Enhance"),
+        gr.inputs.Checkbox(default=True, label="Face_Upsample"),
+        gr.inputs.Number(default=2, label="Rescaling_Factor"),
+        gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity: 0 for better quality, 1 for better identity')
+    ], [
+        gr.outputs.Image(type="numpy", label="Output"),
+        gr.outputs.File(label="Download the output")
+    ],
+    title=title,
+    description=description,
+    article=article,       
+    examples=[
+        ['01.png', True, True, 2, 0.7],
+        ['02.jpg', True, True, 2, 0.7],
+        ['03.jpg', True, True, 2, 0.7],
+        ['04.jpg', True, True, 2, 0.1],
+        ['05.jpg', True, True, 2, 0.1]
+      ]
+    )
+
+demo.queue(concurrency_count=4)
+demo.launch()
\ No newline at end of file
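The Gradio UI is a thin wrapper around `inference`, so the pipeline can be smoke-tested by calling that function directly. The sketch below is only an illustration: it assumes the module-level `demo.queue(...)`/`demo.launch()` calls at the bottom of app.py are guarded (e.g., behind `if __name__ == '__main__':`, which is not the case as committed) so that importing the module does not start the web server, and it reuses `03.jpg`, one of the sample images the script downloads at import time.

```python
import cv2
from app import inference  # triggers the module-level weight download and model setup in app.py

restored_rgb, saved_path = inference(
    image='03.jpg',            # sample image downloaded by app.py
    background_enhance=True,
    face_upsample=True,
    upscale=2,
    codeformer_fidelity=0.7,   # 0 = better quality, 1 = better identity preservation
)
print('saved to', saved_path, 'output shape', restored_rgb.shape)
cv2.imwrite('output/smoke_test.png', cv2.cvtColor(restored_rgb, cv2.COLOR_RGB2BGR))
```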
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4649f711d1c528342fa2dc4bd39ab6730af6dbde
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,3 @@
+ffmpeg
+libsm6
+libxext6
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f97dfde85ebe83708fc1f6f7234a0ef69f18bde5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+addict
+future
+lmdb
+numpy
+opencv-python
+Pillow
+pyyaml
+requests
+scikit-image
+scipy
+tb-nightly
+torch>=1.7.1
+torchvision
+tqdm
+yapf
+lpips
+gdown # supports downloading large files from Google Drive
+# cmake
+# dlib
+# conda install -c conda-forge dlib
\ No newline at end of file