initial commit
- .gitignore +54 -0
- configs/anonymizers/FB_cse.py +28 -0
- configs/anonymizers/FB_cse_mask.py +29 -0
- configs/anonymizers/FB_cse_mask_face.py +29 -0
- configs/anonymizers/face.py +18 -0
- configs/anonymizers/market1501/blackout.py +8 -0
- configs/anonymizers/market1501/person.py +6 -0
- configs/anonymizers/market1501/pixelation16.py +8 -0
- configs/anonymizers/market1501/pixelation8.py +8 -0
- configs/datasets/coco_cse.py +69 -0
- configs/datasets/fdf128.py +24 -0
- configs/datasets/fdf256.py +69 -0
- configs/datasets/fdh.py +89 -0
- configs/datasets/utils.py +12 -0
- configs/defaults.py +45 -0
- configs/discriminators/sg2_discriminator.py +42 -0
- configs/fdf/stylegan.py +14 -0
- configs/fdf/stylegan_fdf128.py +13 -0
- configs/fdh/styleganL.py +16 -0
- configs/fdh/styleganL_nocse.py +14 -0
- configs/generators/stylegan_unet.py +22 -0
- multi_app.py +204 -0
- sg3_torch_utils/LICENSE.txt +97 -0
- sg3_torch_utils/__init__.py +9 -0
- sg3_torch_utils/custom_ops.py +126 -0
- sg3_torch_utils/misc.py +172 -0
- sg3_torch_utils/ops/__init__.py +9 -0
- sg3_torch_utils/ops/bias_act.cpp +99 -0
- sg3_torch_utils/ops/bias_act.cu +173 -0
- sg3_torch_utils/ops/bias_act.h +38 -0
- sg3_torch_utils/ops/bias_act.py +215 -0
- sg3_torch_utils/ops/conv2d_gradfix.py +175 -0
- sg3_torch_utils/ops/conv2d_resample.py +142 -0
- sg3_torch_utils/ops/fma.py +63 -0
- sg3_torch_utils/ops/grid_sample_gradfix.py +88 -0
- sg3_torch_utils/ops/upfirdn2d.cpp +103 -0
- sg3_torch_utils/ops/upfirdn2d.cu +350 -0
- sg3_torch_utils/ops/upfirdn2d.h +59 -0
- sg3_torch_utils/ops/upfirdn2d.py +388 -0
- stylemc.py +295 -0
.gitignore
ADDED
@@ -0,0 +1,54 @@
+# FILES
+*.yaml
+*.pkl
+*.flist
+*.zip
+*.out
+*.npy
+*.gz
+*.ckpt
+*.pth
+*.log
+*.pyc
+*.csv
+*.yml
+*.ods
+*.ods#
+*.json
+build_docker.sh
+
+# Images / Videos
+#*.png
+#*.jpg
+*.jpeg
+*.m4a
+*.mkv
+*.mp4
+
+# Directories created by inpaintron
+.cache/
+test_examples/
+.vscode
+__pycache__
+.debug/
+**/.ipynb_checkpoints/**
+outputs/
+
+
+# From pip setup
+build/
+*.egg-info
+*.egg
+.npm/
+
+# From dockerfile
+.bash_history
+.viminfo
+.local/
+*.pickle
+*.onnx
+
+
+sbatch_files/
+figures/
+image_dump/
configs/anonymizers/FB_cse.py
ADDED
@@ -0,0 +1,28 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.person_detector import CSEPersonDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+from dp2.generator.dummy_generators import MaskOutGenerator
+
+
+maskout_G = L(MaskOutGenerator)(noise="constant")
+
+detector = L(CSEPersonDetector)(
+    mask_rcnn_cfg=dict(),
+    cse_cfg=dict(),
+    cse_post_process_cfg=dict(
+        target_imsize=(288, 160),
+        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+        iou_combine_threshold=0.4,
+        dilation_percentage=0.02,
+        normalize_embedding=False
+    ),
+    score_threshold=0.3,
+    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+    detector="${detector}",
+    cse_person_G_cfg="configs/fdh/styleganL.py",
+)
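The anonymizer configs in this commit are lazy configs: `L(...)` only records a target class and its arguments, and nothing is constructed until the config is instantiated. A minimal usage sketch, mirroring how multi_app.py below loads this exact file (assumes dp2 and tops are installed and a CUDA device is available):

# Minimal sketch of loading and instantiating this config (same calls multi_app.py uses).
from dp2 import utils
from tops.config import instantiate

cfg = utils.load_config("configs/anonymizers/FB_cse.py")
anonymizer = instantiate(cfg.anonymizer, load_cache=False)
# The anonymizer is then called on a uint8 CHW image tensor, e.g. the output of pil2torch in multi_app.py.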
configs/anonymizers/FB_cse_mask.py
ADDED
@@ -0,0 +1,29 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.person_detector import CSEPersonDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+from dp2.generator.dummy_generators import MaskOutGenerator
+
+
+maskout_G = L(MaskOutGenerator)(noise="constant")
+
+detector = L(CSEPersonDetector)(
+    mask_rcnn_cfg=dict(),
+    cse_cfg=dict(),
+    cse_post_process_cfg=dict(
+        target_imsize=(288, 160),
+        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+        iou_combine_threshold=0.4,
+        dilation_percentage=0.02,
+        normalize_embedding=False
+    ),
+    score_threshold=0.3,
+    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+    detector="${detector}",
+    person_G_cfg="configs/fdh/styleganL_nocse.py",
+    cse_person_G_cfg="configs/fdh/styleganL.py",
+)
configs/anonymizers/FB_cse_mask_face.py
ADDED
@@ -0,0 +1,29 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+
+detector = L(CSeMaskFaceDetector)(
+    mask_rcnn_cfg=dict(),
+    face_detector_cfg=dict(),
+    face_post_process_cfg=dict(target_imsize=(256, 256)),
+    cse_cfg=dict(),
+    cse_post_process_cfg=dict(
+        target_imsize=(288, 160),
+        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+        iou_combine_threshold=0.4,
+        dilation_percentage=0.02,
+        normalize_embedding=False
+    ),
+    score_threshold=0.3,
+    cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+    detector="${detector}",
+    face_G_cfg="configs/fdf/stylegan.py",
+    person_G_cfg="configs/fdh/styleganL_nocse.py",
+    cse_person_G_cfg="configs/fdh/styleganL.py",
+    car_G_cfg="configs/generators/dummy/pixelation8.py"
+)
configs/anonymizers/face.py
ADDED
@@ -0,0 +1,18 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.face_detector import FaceDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+
+
+detector = L(FaceDetector)(
+    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
+    face_post_process_cfg=dict(target_imsize=(256, 256)),
+    score_threshold=0.3,
+    cache_directory=common.output_dir.joinpath("face_detection_cache")
+)
+
+
+anonymizer = L(Anonymizer)(
+    detector="${detector}",
+    face_G_cfg="configs/fdf/stylegan.py",
+)
configs/anonymizers/market1501/blackout.py
ADDED
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
configs/anonymizers/market1501/person.py
ADDED
@@ -0,0 +1,6 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
configs/anonymizers/market1501/pixelation16.py
ADDED
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
configs/anonymizers/market1501/pixelation8.py
ADDED
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
configs/datasets/coco_cse.py
ADDED
@@ -0,0 +1,69 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets import CocoCSE
+from dp2.data.build import get_dataloader
+from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
+from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from .utils import final_eval_fn
+
+
+dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+data_dir = Path(dataset_base_dir, "coco_cse")
+data = dict(
+    imsize=(288, 160),
+    im_channels=3,
+    semantic_nc=26,
+    cse_nc=16,
+    train=dict(
+        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
+        loader=L(get_dataloader)(
+            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
+            batch_size="${train.batch_size}",
+            dataset="${..dataset}",
+            infinite=True,
+            gpu_transform=L(torch.nn.Sequential)(*[
+                L(ToFloat)(),
+                L(StyleGANAugmentPipe)(
+                    rotate=0.5, rotate_max=.05,
+                    xint=.5, xint_max=0.05,
+                    scale=.5, scale_std=.05,
+                    aniso=0.5, aniso_std=.05,
+                    xfrac=.5, xfrac_std=.05,
+                    brightness=.5, brightness_std=.05,
+                    contrast=.5, contrast_std=.1,
+                    hue=.5, hue_max=.05,
+                    saturation=.5, saturation_std=.5,
+                    imgfilter=.5, imgfilter_std=.1),
+                L(RandomHorizontalFlip)(p=0.5),
+                L(CreateEmbedding)(),
+                L(Resize)(size="${data.imsize}"),
+                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+                L(CreateCondition)(),
+            ])
+        )
+    ),
+    val=dict(
+        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
+        loader=L(get_dataloader)(
+            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
+            batch_size="${train.batch_size}",
+            dataset="${..dataset}",
+            infinite=False,
+            gpu_transform=L(torch.nn.Sequential)(*[
+                L(ToFloat)(),
+                L(CreateEmbedding)(),
+                L(Resize)(size="${data.imsize}"),
+                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+                L(CreateCondition)(),
+            ])
+        )
+    ),
+    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
+    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
+    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
+)
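Strings such as "${train.batch_size}" and "${..dataset}" are config interpolations: they are resolved against the full config tree at instantiation time, so the loader picks up the batch size defined in configs/defaults.py. A small illustrative sketch of this resolution behaviour, assuming an OmegaConf-style backend (tops.config follows the detectron2 LazyConfig design; this backend detail is an assumption, not shown in this commit):

# Illustrative only: how "${...}" references resolve in an OmegaConf-style config tree.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "train": {"batch_size": 32},
    "data": {"train": {"loader": {"batch_size": "${train.batch_size}"}}},
})
print(cfg.data.train.loader.batch_size)  # 32, resolved from train.batch_size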
configs/datasets/fdf128.py
ADDED
@@ -0,0 +1,24 @@
+from pathlib import Path
+from functools import partial
+from dp2.data.datasets.fdf import FDFDataset
+from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn
+
+data_dir = Path(dataset_base_dir, "fdf")
+data.train.dataset.dirpath = data_dir.joinpath("train")
+data.val.dataset.dirpath = data_dir.joinpath("val")
+data.imsize = (128, 128)
+
+
+data.train_evaluation_fn = partial(
+    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
+data.evaluation_fn = partial(
+    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
+
+data.train.dataset.update(
+    _target_=FDFDataset,
+    imsize="${data.imsize}"
+)
+data.val.dataset.update(
+    _target_=FDFDataset,
+    imsize="${data.imsize}"
+)
configs/datasets/fdf256.py
ADDED
@@ -0,0 +1,69 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets.fdf import FDF256Dataset
+from dp2.data.build import get_dataloader
+from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from dp2.metrics.fid_clip import compute_fid_clip
+from dp2.metrics.ppl import calculate_ppl
+from .utils import final_eval_fn
+
+
+def final_eval_fn(*args, **kwargs):
+    result = compute_metrics_iteratively(*args, **kwargs)
+    result2 = compute_fid_clip(*args, **kwargs)
+    assert all(key not in result for key in result2)
+    result.update(result2)
+    result3 = calculate_ppl(*args, **kwargs,)
+    assert all(key not in result for key in result3)
+    result.update(result3)
+    return result
+
+
+dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+data_dir = Path(dataset_base_dir, "fdf256")
+data = dict(
+    imsize=(256, 256),
+    im_channels=3,
+    semantic_nc=None,
+    cse_nc=None,
+    n_keypoints=None,
+    train=dict(
+        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
+        loader=L(get_dataloader)(
+            shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
+            batch_size="${train.batch_size}",
+            dataset="${..dataset}",
+            infinite=True,
+            gpu_transform=L(torch.nn.Sequential)(*[
+                L(ToFloat)(),
+                L(RandomHorizontalFlip)(p=0.5),
+                L(Resize)(size="${data.imsize}"),
+                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+                L(CreateCondition)(),
+            ])
+        )
+    ),
+    val=dict(
+        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
+        loader=L(get_dataloader)(
+            shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
+            batch_size="${train.batch_size}",
+            dataset="${..dataset}",
+            infinite=False,
+            gpu_transform=L(torch.nn.Sequential)(*[
+                L(ToFloat)(),
+                L(Resize)(size="${data.imsize}"),
+                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+                L(CreateCondition)(),
+            ])
+        )
+    ),
+    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
+    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")),
+    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
+)
configs/datasets/fdh.py
ADDED
@@ -0,0 +1,89 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets.fdh import get_dataloader_fdh_wds
+from dp2.data.utils import get_coco_flipmap
+from dp2.data.transforms.transforms import (
+    Normalize,
+    ToFloat,
+    CreateCondition,
+    RandomHorizontalFlip,
+    CreateEmbedding,
+)
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from dp2.metrics.fid_clip import compute_fid_clip
+from .utils import final_eval_fn
+
+
+def train_eval_fn(*args, **kwargs):
+    result = compute_metrics_iteratively(*args, **kwargs)
+    result2 = compute_fid_clip(*args, **kwargs)
+    assert all(key not in result for key in result2)
+    result.update(result2)
+    return result
+
+
+dataset_base_dir = (
+    os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+)
+metrics_cache = (
+    os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+)
+data_dir = Path(dataset_base_dir, "fdh")
+data = dict(
+    imsize=(288, 160),
+    im_channels=3,
+    cse_nc=16,
+    n_keypoints=17,
+    train=dict(
+        loader=L(get_dataloader_fdh_wds)(
+            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
+            batch_size="${train.batch_size}",
+            num_workers=6,
+            transform=L(torch.nn.Sequential)(
+                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
+            ),
+            gpu_transform=L(torch.nn.Sequential)(
+                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
+                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
+                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
+                L(CreateCondition)(),
+            ),
+            infinite=True,
+            shuffle=True,
+            partial_batches=False,
+            load_embedding=True,
+        )
+    ),
+    val=dict(
+        loader=L(get_dataloader_fdh_wds)(
+            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
+            batch_size="${train.batch_size}",
+            num_workers=6,
+            transform=None,
+            gpu_transform=L(torch.nn.Sequential)(
+                L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False),
+                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
+                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
+                L(CreateCondition)(),
+            ),
+            infinite=False,
+            shuffle=False,
+            partial_batches=True,
+            load_embedding=True,
+        )
+    ),
+    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
+    train_evaluation_fn=functools.partial(
+        train_eval_fn,
+        cache_directory=Path(metrics_cache, "fdh_v7_train"),
+        data_len=int(30e3),
+    ),
+    evaluation_fn=functools.partial(
+        final_eval_fn,
+        cache_directory=Path(metrics_cache, "fdh_v6_val"),
+        data_len=int(30e3),
+    ),
+)
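The train and val paths above use brace-range shard patterns ("out-{000000..001423}.tar"), the convention used by WebDataset-style loaders. get_dataloader_fdh_wds itself is not part of this commit; the sketch below only shows how such a pattern expands, using the braceexpand package (an assumption about the underlying mechanism, for illustration):

# Illustrative expansion of a shard pattern like the ones above.
from braceexpand import braceexpand

shards = list(braceexpand("out-{000000..000003}.tar"))
print(shards)  # ['out-000000.tar', 'out-000001.tar', 'out-000002.tar', 'out-000003.tar']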
configs/datasets/utils.py
ADDED
@@ -0,0 +1,12 @@
+from dp2.metrics.ppl import calculate_ppl
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from dp2.metrics.fid_clip import compute_fid_clip
+
+
+def final_eval_fn(*args, **kwargs):
+    result = compute_metrics_iteratively(*args, **kwargs)
+    result2 = calculate_ppl(*args, **kwargs)
+    result2.update(compute_fid_clip(*args, **kwargs))  # merge PPL and FID-CLIP metrics instead of discarding PPL
+    assert all(key not in result for key in result2)
+    result.update(result2)
+    return result
configs/defaults.py
ADDED
@@ -0,0 +1,45 @@
+import pathlib
+import os
+import torch
+from tops.config import LazyCall as L
+
+if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
+    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
+else:
+    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
+if "BASE_OUTPUT_DIR" in os.environ:
+    BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
+else:
+    BASE_OUTPUT_DIR = pathlib.Path("outputs")
+
+
+
+common = dict(
+    logger_backend=["wandb", "stdout", "json", "image_dumper"],
+    wandb_project="fba_test",
+    output_dir=BASE_OUTPUT_DIR,
+    experiment_name=None,  # Optional experiment name to show on wandb
+)
+
+train = dict(
+    batch_size=32,
+    seed=0,
+    ims_per_log=1024,
+    ims_per_val=int(200e3),
+    max_images_to_train=int(12e6),
+    amp=dict(
+        enabled=True,
+        scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
+        scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
+    ),
+    fp16_ddp_accumulate=False,  # All-gather gradients in fp16?
+    broadcast_buffers=False,
+    bias_act_plugin_enabled=True,
+    grid_sample_gradfix_enabled=True,
+    conv2d_gradfix_enabled=False,
+    channels_last=False,
+)
+
+# exponential moving average
+EMA = dict(rampup=0.05)
+
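The amp block above configures one torch.cuda.amp.GradScaler per network, with "${..enabled}" pointing back at the sibling enabled flag. For reference, a minimal standalone training-step sketch with the same scaler settings; model, optimizer and loss_fn are placeholders, not part of this commit:

import torch

scaler = torch.cuda.amp.GradScaler(init_scale=2**16, growth_factor=4, growth_interval=100)

def train_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales gradients, skips the step on inf/NaN
    scaler.update()                 # adjusts the scale factor for the next iteration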
configs/discriminators/sg2_discriminator.py
ADDED
@@ -0,0 +1,42 @@
+from tops.config import LazyCall as L
+from dp2.discriminator import SG2Discriminator
+import torch
+from dp2.loss import StyleGAN2Loss
+
+
+discriminator = L(SG2Discriminator)(
+    imsize="${data.imsize}",
+    im_channels="${data.im_channels}",
+    min_fmap_resolution=4,
+    max_cnum_mul=8,
+    cnum=80,
+    input_condition=True,
+    conv_clamp=256,
+    input_cse=False,
+    cse_nc="${data.cse_nc}"
+)
+
+
+loss_fnc = L(StyleGAN2Loss)(
+    lazy_regularization=True,
+    lazy_reg_interval=16,
+    r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
+    EP_lambd=0.001,
+    pl_reg_opts=dict(weight=0, batch_shrink=2, start_nimg=int(1e6), pl_decay=0.01)
+)
+
+def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
+    if lazy_regularization:
+        # From Analyzing and improving the image quality of stylegan, CVPR 2020
+        c = lazy_reg_interval / (lazy_reg_interval + 1)
+        betas = [beta ** c for beta in betas]
+        lr *= c
+        print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
+    return type(lr=lr, betas=betas, **kwargs)
+
+
+D_optim = L(build_D_optim)(
+    type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
+    lazy_regularization="${loss_fnc.lazy_regularization}",
+    lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
+G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
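build_D_optim applies the lazy-regularization correction from StyleGAN2: when the R1 penalty is only applied every lazy_reg_interval steps, the learning rate and Adam betas are rescaled by c = interval / (interval + 1). Worked out for the defaults above (interval 16, lr 0.001, betas (0.0, 0.99)):

# Worked example of the correction computed inside build_D_optim above.
interval = 16
c = interval / (interval + 1)            # ≈ 0.9412
lr = 0.001 * c                           # ≈ 0.000941
betas = [b ** c for b in (0.0, 0.99)]    # ≈ [0.0, 0.9906]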
configs/fdf/stylegan.py
ADDED
@@ -0,0 +1,14 @@
+from ..generators.stylegan_unet import generator
+from ..datasets.fdf256 import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(35e6)
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.input_cse = False
+loss_fnc.r1_opts.lambd = 1
+train.ims_per_val = int(2e6)
+
+common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
+common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
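common.model_url and common.model_md5sum describe a pretrained checkpoint and its checksum. How dp2/tops downloads and verifies the file is not shown in this commit; the sketch below is only a generic way to check a downloaded file against such an md5 digest, with a hypothetical local path:

# Hedged sketch: verify a local file against common.model_md5sum (generic, not dp2's own loader).
import hashlib

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# assert md5sum("pretrained_checkpoints/fdf_stylegan.ckpt") == "e8e32190528af2ed75f0cb792b7f2b07"  # hypothetical path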
configs/fdf/stylegan_fdf128.py
ADDED
@@ -0,0 +1,13 @@
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..datasets.fdf128 import data
+from ..generators.stylegan_unet import generator
+from ..defaults import train, common, EMA
+from tops.config import LazyCall as L
+
+train.max_images_to_train = int(25e6)
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.cnum = 128
+generator.max_cnum_mul = 4
+generator.input_cse = False
+loss_fnc.r1_opts.lambd = .1
configs/fdh/styleganL.py
ADDED
@@ -0,0 +1,16 @@
+from tops.config import LazyCall as L
+from ..generators.stylegan_unet import generator
+from ..datasets.fdh import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(50e6)
+train.batch_size = 64
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+data.train.loader.num_workers = 4
+train.ims_per_val = int(1e6)
+loss_fnc.r1_opts.lambd = .1
+
+common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
+common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
configs/fdh/styleganL_nocse.py
ADDED
@@ -0,0 +1,14 @@
+from tops.config import LazyCall as L
+from ..generators.stylegan_unet import generator
+from ..datasets.fdh import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(50e6)
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.input_cse = False
+data.load_embeddings = False
+common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
+common.model_md5sum = "fda0d809741bc67487abada793975c37"
+generator.fix_errors = False
configs/generators/stylegan_unet.py
ADDED
@@ -0,0 +1,22 @@
+from dp2.generator.stylegan_unet import StyleGANUnet
+from tops.config import LazyCall as L
+
+generator = L(StyleGANUnet)(
+    imsize="${data.imsize}",
+    im_channels="${data.im_channels}",
+    min_fmap_resolution=8,
+    cnum=64,
+    max_cnum_mul=8,
+    n_middle_blocks=0,
+    z_channels=512,
+    mask_output=True,
+    conv_clamp=256,
+    input_cse=True,
+    scale_grad=True,
+    cse_nc="${data.cse_nc}",
+    w_dim=512,
+    n_keypoints="${data.n_keypoints}",
+    input_keypoints=False,
+    input_keypoint_indices=[],
+    fix_errors=True
+)
multi_app.py
ADDED
@@ -0,0 +1,204 @@
+import os
+os.system("pip install git+https://github.com/hukkelas/deep_privacy2@36c2c843cfd3022ebc100e9f8579fb2b82f8bde6")
+from collections import defaultdict
+import gradio
+import numpy as np
+import torch
+import cv2
+from PIL import Image
+from dp2 import utils
+from tops.config import instantiate
+import tops
+import gradio.inputs
+from stylemc import get_and_cache_direction, get_styles
+
+
+class GuidedDemo:
+    def __init__(self, face_anonymizer, cfg_face) -> None:
+        self.anonymizer = face_anonymizer
+        assert sum([x is not None for x in list(face_anonymizer.generators.values())]) == 1
+        self.generator = [x for x in list(face_anonymizer.generators.values()) if x is not None][0]
+        face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg)
+        face_G_cfg.train.batch_size = 1
+        self.dl = instantiate(face_G_cfg.data.val.loader)
+        self.cache_dir = face_G_cfg.output_dir
+        self.precompute_edits()
+
+    def precompute_edits(self):
+        self.precomputed_edits = set()
+        for edit in self.precomputed_edits:
+            get_and_cache_direction(self.cache_dir, self.dl, self.generator, edit)
+        if self.cache_dir.joinpath("stylemc_cache").is_dir():
+            for path in self.cache_dir.joinpath("stylemc_cache").iterdir():
+                text_prompt = path.stem.replace("_", " ")
+                self.precomputed_edits.add(text_prompt)
+                print(text_prompt)
+        self.edits = defaultdict(defaultdict)
+
+    def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None):
+        if not isinstance(img, torch.Tensor):
+            img, cache_id = pil2torch(img)
+            img = tops.to_cuda(img)
+
+        current_box_idx = current_box_idx % len(current_boxes)
+        edited_styles = [s.clone() for s in current_styles]
+        for face_idx, face_edits in edits.items():
+            for prompt, strength in face_edits.items():
+                direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt)
+                edited_styles[int(face_idx)] += direction * strength
+            update_identity[int(face_idx)] = True
+        assert img.dtype == torch.uint8
+        img = self.anonymizer(
+            img, truncation_value=0,
+            multi_modal_truncation=True, amp=True,
+            cache_id=cache_id,
+            all_styles=edited_styles,
+            update_identity=update_identity)
+        update_identity = [True for i in range(len(update_identity))]
+        img = utils.im2numpy(img)
+        if show_boxes:
+            x0, y0, x1, y1 = [int(_) for _ in current_boxes[int(current_box_idx)]]
+            img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1)
+        return img, update_identity
+
+    def update_image(self, img, show_boxes):
+        img, cache_id = pil2torch(img)
+        img = tops.to_cuda(img)
+        det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0]
+        current_styles = []
+        for i in range(len(det)):
+            s = get_styles(
+                np.random.randint(0, 999999), self.generator,
+                None, truncation_value=0)
+            current_styles.append(s)
+        update_identity = [True for i in range(len(det))]
+        current_boxes = np.array(det.boxes)
+        edits = defaultdict(defaultdict)
+        cur_face_idx = -1 % len(current_boxes)
+        img, update_identity = self.anonymize(img, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits, cache_id=cache_id)
+        return img, current_styles, current_boxes, update_identity, edits, cur_face_idx
+
+    def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits):
+        cur_face_idx = (cur_face_idx + change) % len(current_boxes)
+        img, update_identity = self.anonymize(input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits)
+        return img, update_identity, cur_face_idx
+
+    def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits):
+        face_idx = face_idx % len(current_boxes)
+        edits[face_idx][prompt] = strength
+        img, update_identity = self.anonymize(input_image, show_boxes, face_idx, current_styles, current_boxes, update_identity, edits)
+        return img, update_identity, edits
+
+    def setup_interface(self):
+        current_styles = gradio.State()
+        current_boxes = gradio.State(None)
+        update_identity = gradio.State([])
+        edits = gradio.State([])
+        with gradio.Row():
+            input_image = gradio.Image(
+                type="pil", label="Upload your image or try the example below!", source="webcam")
+            output_image = gradio.Image(type="numpy", label="Output")
+        with gradio.Row():
+            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
+        with gradio.Row():
+            show_boxes = gradio.Checkbox(value=True, label="Show Selected")
+            cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False)
+            previous = gradio.Button("Previous Person")
+            next_ = gradio.Button("Next Person")
+        with gradio.Row():
+            text_prompt = gradio.Textbox(
+                placeholder=" | ".join(list(self.precomputed_edits)),
+                label="Text Prompt for Edit")
+            edit_strength = gradio.Slider(0, 5, step=.01)
+            add_btn = gradio.Button("Add Edit")
+        add_btn.click(self.add_style, inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity, edits])
+        update_btn.click(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
+        input_image.change(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
+        previous.click(self.change_face, inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])
+        next_.click(self.change_face, inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])
+
+        show_boxes.change(self.anonymize, inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity])
+
+
+cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
+anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
+anonymizer_body.initialize_tracker(fps=1)
+cfg_face = utils.load_config("configs/anonymizers/face.py")
+anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
+anonymizer_face.initialize_tracker(fps=1)
+
+class WebcamDemo:
+
+    def __init__(self, anonymizer) -> None:
+        self.anonymizer = anonymizer
+        with gradio.Row():
+            input_image = gradio.Image(type="pil", source="webcam", streaming=True)
+            output_image = gradio.Image(type="numpy", label="Output")
+        visualize_det = gradio.Checkbox(value=False, label="Show Detections")
+        input_image.stream(self.anonymize, [input_image, visualize_det], [output_image])
+        self.track = True
+
+    def anonymize(self, img: Image, visualize_detection: bool):
+        img, cache_id = pil2torch(img)
+        img = tops.to_cuda(img)
+        if visualize_detection:
+            img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
+        else:
+            img = self.anonymizer(
+                img, truncation_value=0, multi_modal_truncation=True, amp=True,
+                cache_id=cache_id, track=self.track)
+        img = utils.im2numpy(img)
+        return img
+
+class ExampleDemo(WebcamDemo):
+
+    def __init__(self, anonymizer) -> None:
+        self.anonymizer = anonymizer
+        with gradio.Row():
+            input_image = gradio.Image(type="pil", source="webcam")
+            output_image = gradio.Image(type="numpy", label="Output")
+        with gradio.Row():
+            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
+            visualize_det = gradio.Checkbox(value=False, label="Show Detections")
+        visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
+        gradio.Examples(
+            ["media2/erling.jpg", "media2/regjeringen.jpg"], inputs=[input_image]
+        )
+        update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
+        input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
+        self.track = False
+
+
+class Information:
+
+    def __init__(self) -> None:
+        gradio.Markdown("## <center> Face Anonymization Architecture </center>")
+        gradio.Markdown("---")
+        gradio.Image(value="media2/overall_architecture.png")
+        gradio.Markdown("## <center> Full-Body Anonymization Architecture </center>")
+        gradio.Markdown("---")
+        gradio.Image(value="media2/full_body.png")
+        gradio.Markdown("### <center> Generative Adversarial Networks </center>")
+        gradio.Markdown("---")
+        gradio.Image(value="media2/gan_architecture.png")
+
+
+def pil2torch(img: Image.Image):
+    img = img.convert("RGB")
+    img = np.array(img)
+    img = np.rollaxis(img, 2)
+    return torch.from_numpy(img), None
+
+
+with gradio.Blocks() as demo:
+    gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
+    gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
+    with gradio.Tab("Text-Guided Anonymization"):
+        GuidedDemo(anonymizer_face, cfg_face).setup_interface()
+    with gradio.Tab("Live Full-Body"):
+        WebcamDemo(anonymizer_body)
+    with gradio.Tab("Live Face"):
+        WebcamDemo(anonymizer_face)
+
+
+demo.launch()
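pil2torch converts a PIL image into the CHW uint8 tensor layout the anonymizer asserts on (img.dtype == torch.uint8). A quick self-contained check of that layout, replicating the same three operations on a synthetic image:

from PIL import Image
import numpy as np
import torch

pil_img = Image.fromarray(np.zeros((288, 160, 3), dtype=np.uint8))   # HWC uint8
arr = np.rollaxis(np.array(pil_img.convert("RGB")), 2)               # roll channels first -> CHW
t = torch.from_numpy(arr)
print(t.shape, t.dtype)  # torch.Size([3, 288, 160]) torch.uint8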
sg3_torch_utils/LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
+Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+
+NVIDIA Source Code License for StyleGAN3
+
+
+=======================================================================
+
+1. Definitions
+
+"Licensor" means any person or entity that distributes its Work.
+
+"Software" means the original work of authorship made available under
+this License.
+
+"Work" means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+The terms "reproduce," "reproduction," "derivative works," and
+"distribution" have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are "made available" under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+2. License Grants
+
+2.1 Copyright Grant. Subject to the terms and conditions of this
+License, each Licensor grants to you a perpetual, worldwide,
+non-exclusive, royalty-free, copyright license to reproduce,
+prepare derivative works of, publicly display, publicly perform,
+sublicense and distribute its Work and any resulting derivative
+works in any form.
+
+3. Limitations
+
+3.1 Redistribution. You may reproduce or distribute the Work only
+if (a) you do so under this License, (b) you include a complete
+copy of this License with your distribution, and (c) you retain
+without modification any copyright, patent, trademark, or
+attribution notices that are present in the Work.
+
+3.2 Derivative Works. You may specify that additional or different
+terms apply to the use, reproduction, and distribution of your
+derivative works of the Work ("Your Terms") only if (a) Your Terms
+provide that the use limitation in Section 3.3 applies to your
+derivative works, and (b) you identify the specific derivative
+works that are subject to Your Terms. Notwithstanding Your Terms,
+this License (including the redistribution requirements in Section
+3.1) will continue to apply to the Work itself.
+
+3.3 Use Limitation. The Work and any derivative works thereof only
+may be used or intended for use non-commercially. Notwithstanding
+the foregoing, NVIDIA and its affiliates may use the Work and any
+derivative works commercially. As used herein, "non-commercially"
+means for research or evaluation purposes only.
+
+3.4 Patent Claims. If you bring or threaten to bring a patent claim
+against any Licensor (including any claim, cross-claim or
+counterclaim in a lawsuit) to enforce any patents that you allege
+are infringed by any Work, then your rights under this License from
+such Licensor (including the grant in Section 2.1) will terminate
+immediately.
+
+3.5 Trademarks. This License does not grant any rights to use any
+Licensor's or its affiliates' names, logos, or trademarks, except
+as necessary to reproduce the notices described in this License.
+
+3.6 Termination. If you violate any term of this License, then your
+rights under this License (including the grant in Section 2.1) will
+terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================
sg3_torch_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
sg3_torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        # Compile and load.
+        verbose_build = (verbosity == 'full')
+
+        # Incremental build md5sum trickery. Copies all the input source files
+        # into a cached build directory under a combined md5 digest of the input
+        # source files. Copying is done only if the combined digest has changed.
+        # This keeps input file timestamps and filenames the same as in previous
+        # extension builds, allowing for fast incremental rebuilds.
+        #
+        # This optimization is done only in case all the source files reside in
+        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+        # environment variable is set (we take this as a signal that the user
+        # actually cares about this.)
+        source_dirs_set = set(os.path.dirname(source) for source in sources)
+        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+            # Compute a combined hash digest for all source files in the same
+            # custom op directory (usually .cu, .cpp, .py and .h files).
+            hash_md5 = hashlib.md5()
+            for src in all_source_files:
+                with open(src, 'rb') as f:
+                    hash_md5.update(f.read())
+            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+            if not os.path.isdir(digest_build_dir):
+                os.makedirs(digest_build_dir, exist_ok=True)
+                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+                if baton.try_acquire():
+                    try:
+                        for src in all_source_files:
+                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+                    finally:
+                        baton.release()
+                else:
+                    # Someone else is copying source files under the digest dir,
+                    # wait until done and continue.
+                    baton.wait()
+            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+                                           verbose=verbose_build, sources=digest_sources, **build_kwargs)
+        else:
+            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+        module = importlib.import_module(module_name)
+
+    except:
+        if verbosity == 'brief':
+            print('Failed!')
+        raise
+
+    # Print status and add to cache.
+    if verbosity == 'full':
+        print(f'Done setting up PyTorch plugin "{module_name}".')
+    elif verbosity == 'brief':
+        print('Done.')
+    _cached_plugins[module_name] = module
+    return module
+
+#----------------------------------------------------------------------------
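get_plugin compiles a C++/CUDA extension with torch.utils.cpp_extension.load and caches the resulting module per module name. The ops added in this commit (bias_act, upfirdn2d) call it with their .cpp/.cu sources; the exact call sites live in those files, so the snippet below is only an illustrative sketch of the pattern, and the keyword arguments such as extra_cuda_cflags are assumptions rather than the repository's actual flags:

# Illustrative sketch of how an op module might load its CUDA plugin via get_plugin.
from sg3_torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name="upfirdn2d_plugin",
    sources=[
        "sg3_torch_utils/ops/upfirdn2d.cpp",
        "sg3_torch_utils/ops/upfirdn2d.cu",
    ],
    extra_cuda_cflags=["--use_fast_math"],  # assumption; the real flags are set inside upfirdn2d.py
)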
sg3_torch_utils/misc.py
ADDED
@@ -0,0 +1,172 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+    value = np.asarray(value)
+    if shape is not None:
+        shape = tuple(shape)
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if device is None:
+        device = torch.device('cpu')
+    if memory_format is None:
+        memory_format = torch.contiguous_format
+
+    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+    tensor = _constant_cache.get(key, None)
+    if tensor is None:
+        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+        if shape is not None:
+            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+        tensor = tensor.contiguous(memory_format=memory_format)
+        _constant_cache[key] = tensor
+    return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+    nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+        assert isinstance(input, torch.Tensor)
+        if posinf is None:
+            posinf = torch.finfo(input.dtype).max
+        if neginf is None:
+            neginf = torch.finfo(input.dtype).min
+        assert nan == 0
+        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+    symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+    def __enter__(self):
+        super().__enter__()
+        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+        return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+    if tensor.ndim != len(ref_shape):
+        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+        if ref_size is None:
+            pass
+        elif isinstance(ref_size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+        elif isinstance(size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+        elif size != ref_size:
+            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+    def decorator(*args, **kwargs):
+        with torch.autograd.profiler.record_function(fn.__name__):
+            return fn(*args, **kwargs)
+    decorator.__name__ = fn.__name__
+    return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+        assert len(dataset) > 0
+        assert num_replicas > 0
+        assert 0 <= rank < num_replicas
+        assert 0 <= window_size <= 1
+        super().__init__(dataset)
+        self.dataset = dataset
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+        self.window_size = window_size
+
+    def __iter__(self):
+        order = np.arange(len(self.dataset))
+        rnd = None
+        window = 0
+        if self.shuffle:
|
127 |
+
rnd = np.random.RandomState(self.seed)
|
128 |
+
rnd.shuffle(order)
|
129 |
+
window = int(np.rint(order.size * self.window_size))
|
130 |
+
|
131 |
+
idx = 0
|
132 |
+
while True:
|
133 |
+
i = idx % order.size
|
134 |
+
if idx % self.num_replicas == self.rank:
|
135 |
+
yield order[i]
|
136 |
+
if window >= 2:
|
137 |
+
j = (i - rnd.randint(window)) % order.size
|
138 |
+
order[i], order[j] = order[j], order[i]
|
139 |
+
idx += 1
|
140 |
+
|
141 |
+
#----------------------------------------------------------------------------
|
142 |
+
# Utilities for operating with torch.nn.Module parameters and buffers.
|
143 |
+
|
144 |
+
def params_and_buffers(module):
|
145 |
+
assert isinstance(module, torch.nn.Module)
|
146 |
+
return list(module.parameters()) + list(module.buffers())
|
147 |
+
|
148 |
+
def named_params_and_buffers(module):
|
149 |
+
assert isinstance(module, torch.nn.Module)
|
150 |
+
return list(module.named_parameters()) + list(module.named_buffers())
|
151 |
+
|
152 |
+
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
153 |
+
assert isinstance(src_module, torch.nn.Module)
|
154 |
+
assert isinstance(dst_module, torch.nn.Module)
|
155 |
+
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
156 |
+
for name, tensor in named_params_and_buffers(dst_module):
|
157 |
+
assert (name in src_tensors) or (not require_all)
|
158 |
+
if name in src_tensors:
|
159 |
+
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
160 |
+
|
161 |
+
#----------------------------------------------------------------------------
|
162 |
+
# Context manager for easily enabling/disabling DistributedDataParallel
|
163 |
+
# synchronization.
|
164 |
+
|
165 |
+
@contextlib.contextmanager
|
166 |
+
def ddp_sync(module, sync):
|
167 |
+
assert isinstance(module, torch.nn.Module)
|
168 |
+
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
169 |
+
yield
|
170 |
+
else:
|
171 |
+
with module.no_sync():
|
172 |
+
yield
|
sg3_torch_utils/ops/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
sg3_torch_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "bias_act.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
17 |
+
{
|
18 |
+
if (x.dim() != y.dim())
|
19 |
+
return false;
|
20 |
+
for (int64_t i = 0; i < x.dim(); i++)
|
21 |
+
{
|
22 |
+
if (x.size(i) != y.size(i))
|
23 |
+
return false;
|
24 |
+
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
25 |
+
return false;
|
26 |
+
}
|
27 |
+
return true;
|
28 |
+
}
|
29 |
+
|
30 |
+
//------------------------------------------------------------------------
|
31 |
+
|
32 |
+
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
33 |
+
{
|
34 |
+
// Validate arguments.
|
35 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
36 |
+
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
37 |
+
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
38 |
+
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
39 |
+
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
40 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
41 |
+
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
42 |
+
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
43 |
+
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
44 |
+
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
45 |
+
|
46 |
+
// Validate layout.
|
47 |
+
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
48 |
+
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
49 |
+
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
50 |
+
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
51 |
+
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
52 |
+
|
53 |
+
// Create output tensor.
|
54 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
55 |
+
torch::Tensor y = torch::empty_like(x);
|
56 |
+
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
57 |
+
|
58 |
+
// Initialize CUDA kernel parameters.
|
59 |
+
bias_act_kernel_params p;
|
60 |
+
p.x = x.data_ptr();
|
61 |
+
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
62 |
+
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
63 |
+
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
64 |
+
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
65 |
+
p.y = y.data_ptr();
|
66 |
+
p.grad = grad;
|
67 |
+
p.act = act;
|
68 |
+
p.alpha = alpha;
|
69 |
+
p.gain = gain;
|
70 |
+
p.clamp = clamp;
|
71 |
+
p.sizeX = (int)x.numel();
|
72 |
+
p.sizeB = (int)b.numel();
|
73 |
+
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
74 |
+
|
75 |
+
// Choose CUDA kernel.
|
76 |
+
void* kernel;
|
77 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
78 |
+
{
|
79 |
+
kernel = choose_bias_act_kernel<scalar_t>(p);
|
80 |
+
});
|
81 |
+
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
82 |
+
|
83 |
+
// Launch CUDA kernel.
|
84 |
+
p.loopX = 4;
|
85 |
+
int blockSize = 4 * 32;
|
86 |
+
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
87 |
+
void* args[] = {&p};
|
88 |
+
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
89 |
+
return y;
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
|
94 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
95 |
+
{
|
96 |
+
m.def("bias_act", &bias_act);
|
97 |
+
}
|
98 |
+
|
99 |
+
//------------------------------------------------------------------------
|
sg3_torch_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "bias_act.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
//------------------------------------------------------------------------
|
21 |
+
// CUDA kernel.
|
22 |
+
|
23 |
+
template <class T, int A>
|
24 |
+
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
25 |
+
{
|
26 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
27 |
+
int G = p.grad;
|
28 |
+
scalar_t alpha = (scalar_t)p.alpha;
|
29 |
+
scalar_t gain = (scalar_t)p.gain;
|
30 |
+
scalar_t clamp = (scalar_t)p.clamp;
|
31 |
+
scalar_t one = (scalar_t)1;
|
32 |
+
scalar_t two = (scalar_t)2;
|
33 |
+
scalar_t expRange = (scalar_t)80;
|
34 |
+
scalar_t halfExpRange = (scalar_t)40;
|
35 |
+
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
36 |
+
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
37 |
+
|
38 |
+
// Loop over elements.
|
39 |
+
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
40 |
+
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
41 |
+
{
|
42 |
+
// Load.
|
43 |
+
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
44 |
+
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
45 |
+
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
46 |
+
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
47 |
+
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
48 |
+
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
49 |
+
scalar_t y = 0;
|
50 |
+
|
51 |
+
// Apply bias.
|
52 |
+
((G == 0) ? x : xref) += b;
|
53 |
+
|
54 |
+
// linear
|
55 |
+
if (A == 1)
|
56 |
+
{
|
57 |
+
if (G == 0) y = x;
|
58 |
+
if (G == 1) y = x;
|
59 |
+
}
|
60 |
+
|
61 |
+
// relu
|
62 |
+
if (A == 2)
|
63 |
+
{
|
64 |
+
if (G == 0) y = (x > 0) ? x : 0;
|
65 |
+
if (G == 1) y = (yy > 0) ? x : 0;
|
66 |
+
}
|
67 |
+
|
68 |
+
// lrelu
|
69 |
+
if (A == 3)
|
70 |
+
{
|
71 |
+
if (G == 0) y = (x > 0) ? x : x * alpha;
|
72 |
+
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
73 |
+
}
|
74 |
+
|
75 |
+
// tanh
|
76 |
+
if (A == 4)
|
77 |
+
{
|
78 |
+
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
79 |
+
if (G == 1) y = x * (one - yy * yy);
|
80 |
+
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
81 |
+
}
|
82 |
+
|
83 |
+
// sigmoid
|
84 |
+
if (A == 5)
|
85 |
+
{
|
86 |
+
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
87 |
+
if (G == 1) y = x * yy * (one - yy);
|
88 |
+
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
89 |
+
}
|
90 |
+
|
91 |
+
// elu
|
92 |
+
if (A == 6)
|
93 |
+
{
|
94 |
+
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
95 |
+
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
96 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
97 |
+
}
|
98 |
+
|
99 |
+
// selu
|
100 |
+
if (A == 7)
|
101 |
+
{
|
102 |
+
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
103 |
+
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
104 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
105 |
+
}
|
106 |
+
|
107 |
+
// softplus
|
108 |
+
if (A == 8)
|
109 |
+
{
|
110 |
+
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
111 |
+
if (G == 1) y = x * (one - exp(-yy));
|
112 |
+
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
113 |
+
}
|
114 |
+
|
115 |
+
// swish
|
116 |
+
if (A == 9)
|
117 |
+
{
|
118 |
+
if (G == 0)
|
119 |
+
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
120 |
+
else
|
121 |
+
{
|
122 |
+
scalar_t c = exp(xref);
|
123 |
+
scalar_t d = c + one;
|
124 |
+
if (G == 1)
|
125 |
+
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
126 |
+
else
|
127 |
+
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
128 |
+
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
// Apply gain.
|
133 |
+
y *= gain * dy;
|
134 |
+
|
135 |
+
// Clamp.
|
136 |
+
if (clamp >= 0)
|
137 |
+
{
|
138 |
+
if (G == 0)
|
139 |
+
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
140 |
+
else
|
141 |
+
y = (yref > -clamp & yref < clamp) ? y : 0;
|
142 |
+
}
|
143 |
+
|
144 |
+
// Store.
|
145 |
+
((T*)p.y)[xi] = (T)y;
|
146 |
+
}
|
147 |
+
}
|
148 |
+
|
149 |
+
//------------------------------------------------------------------------
|
150 |
+
// CUDA kernel selection.
|
151 |
+
|
152 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
153 |
+
{
|
154 |
+
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
155 |
+
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
156 |
+
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
157 |
+
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
158 |
+
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
159 |
+
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
160 |
+
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
161 |
+
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
162 |
+
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
163 |
+
return NULL;
|
164 |
+
}
|
165 |
+
|
166 |
+
//------------------------------------------------------------------------
|
167 |
+
// Template specializations.
|
168 |
+
|
169 |
+
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
170 |
+
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
171 |
+
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
172 |
+
|
173 |
+
//------------------------------------------------------------------------
|
sg3_torch_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
//------------------------------------------------------------------------
|
10 |
+
// CUDA kernel parameters.
|
11 |
+
|
12 |
+
struct bias_act_kernel_params
|
13 |
+
{
|
14 |
+
const void* x; // [sizeX]
|
15 |
+
const void* b; // [sizeB] or NULL
|
16 |
+
const void* xref; // [sizeX] or NULL
|
17 |
+
const void* yref; // [sizeX] or NULL
|
18 |
+
const void* dy; // [sizeX] or NULL
|
19 |
+
void* y; // [sizeX]
|
20 |
+
|
21 |
+
int grad;
|
22 |
+
int act;
|
23 |
+
float alpha;
|
24 |
+
float gain;
|
25 |
+
float clamp;
|
26 |
+
|
27 |
+
int sizeX;
|
28 |
+
int sizeB;
|
29 |
+
int stepB;
|
30 |
+
int loopX;
|
31 |
+
};
|
32 |
+
|
33 |
+
//------------------------------------------------------------------------
|
34 |
+
// CUDA kernel selection.
|
35 |
+
|
36 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
37 |
+
|
38 |
+
//------------------------------------------------------------------------
|
sg3_torch_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient bias and activation."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import traceback
|
16 |
+
|
17 |
+
from .. import custom_ops
|
18 |
+
from easydict import EasyDict
|
19 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
activation_funcs = {
|
23 |
+
'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
24 |
+
'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
25 |
+
'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
26 |
+
'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
27 |
+
'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
28 |
+
'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
29 |
+
'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
30 |
+
'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
31 |
+
'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
32 |
+
}
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
_inited = False
|
37 |
+
_plugin = None
|
38 |
+
enabled = False
|
39 |
+
_null_tensor = torch.empty([0])
|
40 |
+
|
41 |
+
def _init():
|
42 |
+
global _inited, _plugin
|
43 |
+
if not _inited:
|
44 |
+
_inited = True
|
45 |
+
sources = ['bias_act.cpp', 'bias_act.cu']
|
46 |
+
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
47 |
+
try:
|
48 |
+
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
49 |
+
except:
|
50 |
+
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
51 |
+
return _plugin is not None
|
52 |
+
|
53 |
+
#----------------------------------------------------------------------------
|
54 |
+
|
55 |
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
56 |
+
r"""Fused bias and activation function.
|
57 |
+
|
58 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
59 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
60 |
+
the fused op is considerably more efficient than performing the same calculation
|
61 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
62 |
+
but not third order gradients.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
x: Input activation tensor. Can be of any shape.
|
66 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
67 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
68 |
+
corresponding to `dim`.
|
69 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
70 |
+
The value of `dim` is ignored if `b` is not specified.
|
71 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
72 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
73 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
74 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
75 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
76 |
+
See `activation_funcs` for the default scaling of each activation function.
|
77 |
+
If unsure, consider specifying 1.
|
78 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
79 |
+
the clamping (default).
|
80 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Tensor of the same shape and datatype as `x`.
|
84 |
+
"""
|
85 |
+
assert isinstance(x, torch.Tensor)
|
86 |
+
assert impl in ['ref', 'cuda']
|
87 |
+
if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
|
88 |
+
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
89 |
+
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
90 |
+
|
91 |
+
#----------------------------------------------------------------------------
|
92 |
+
|
93 |
+
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
94 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
95 |
+
"""
|
96 |
+
assert isinstance(x, torch.Tensor)
|
97 |
+
assert clamp is None or clamp >= 0
|
98 |
+
spec = activation_funcs[act]
|
99 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
100 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
101 |
+
clamp = float(clamp if clamp is not None else -1)
|
102 |
+
|
103 |
+
# Add bias.
|
104 |
+
if b is not None:
|
105 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
106 |
+
assert 0 <= dim < x.ndim
|
107 |
+
assert b.shape[0] == x.shape[dim]
|
108 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
109 |
+
|
110 |
+
# Evaluate activation function.
|
111 |
+
alpha = float(alpha)
|
112 |
+
x = spec.func(x, alpha=alpha)
|
113 |
+
|
114 |
+
# Scale by gain.
|
115 |
+
gain = float(gain)
|
116 |
+
if gain != 1:
|
117 |
+
x = x * gain
|
118 |
+
|
119 |
+
# Clamp.
|
120 |
+
if clamp >= 0:
|
121 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
122 |
+
return x
|
123 |
+
|
124 |
+
#----------------------------------------------------------------------------
|
125 |
+
|
126 |
+
_bias_act_cuda_cache = dict()
|
127 |
+
|
128 |
+
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
129 |
+
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
130 |
+
"""
|
131 |
+
# Parse arguments.
|
132 |
+
assert clamp is None or clamp >= 0
|
133 |
+
spec = activation_funcs[act]
|
134 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
135 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
136 |
+
clamp = float(clamp if clamp is not None else -1)
|
137 |
+
|
138 |
+
# Lookup from cache.
|
139 |
+
key = (dim, act, alpha, gain, clamp)
|
140 |
+
if key in _bias_act_cuda_cache:
|
141 |
+
return _bias_act_cuda_cache[key]
|
142 |
+
|
143 |
+
# Forward op.
|
144 |
+
class BiasActCuda(torch.autograd.Function):
|
145 |
+
@staticmethod
|
146 |
+
@custom_fwd(cast_inputs=torch.float16)
|
147 |
+
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
148 |
+
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
149 |
+
x = x.contiguous(memory_format=ctx.memory_format)
|
150 |
+
b = b.contiguous() if b is not None else _null_tensor
|
151 |
+
y = x
|
152 |
+
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
153 |
+
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
154 |
+
ctx.save_for_backward(
|
155 |
+
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
156 |
+
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
157 |
+
y if 'y' in spec.ref else _null_tensor)
|
158 |
+
return y
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
@custom_bwd
|
162 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
163 |
+
dy = dy.contiguous(memory_format=ctx.memory_format)
|
164 |
+
x, b, y = ctx.saved_tensors
|
165 |
+
dx = None
|
166 |
+
db = None
|
167 |
+
|
168 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
169 |
+
dx = dy
|
170 |
+
if act != 'linear' or gain != 1 or clamp >= 0:
|
171 |
+
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
172 |
+
|
173 |
+
if ctx.needs_input_grad[1]:
|
174 |
+
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
175 |
+
|
176 |
+
return dx, db
|
177 |
+
|
178 |
+
# Backward op.
|
179 |
+
class BiasActCudaGrad(torch.autograd.Function):
|
180 |
+
@staticmethod
|
181 |
+
@custom_fwd(cast_inputs=torch.float16)
|
182 |
+
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
183 |
+
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
184 |
+
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
185 |
+
ctx.save_for_backward(
|
186 |
+
dy if spec.has_2nd_grad else _null_tensor,
|
187 |
+
x, b, y)
|
188 |
+
return dx
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
@custom_bwd
|
192 |
+
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
193 |
+
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
194 |
+
dy, x, b, y = ctx.saved_tensors
|
195 |
+
d_dy = None
|
196 |
+
d_x = None
|
197 |
+
d_b = None
|
198 |
+
d_y = None
|
199 |
+
|
200 |
+
if ctx.needs_input_grad[0]:
|
201 |
+
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
202 |
+
|
203 |
+
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
204 |
+
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
205 |
+
|
206 |
+
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
207 |
+
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
208 |
+
|
209 |
+
return d_dy, d_x, d_b, d_y
|
210 |
+
|
211 |
+
# Add to cache.
|
212 |
+
_bias_act_cuda_cache[key] = BiasActCuda
|
213 |
+
return BiasActCuda
|
214 |
+
|
215 |
+
#----------------------------------------------------------------------------
|
sg3_torch_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
10 |
+
arbitrarily high order gradients with zero performance penalty."""
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
import contextlib
|
14 |
+
import torch
|
15 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
16 |
+
|
17 |
+
# pylint: disable=redefined-builtin
|
18 |
+
# pylint: disable=arguments-differ
|
19 |
+
# pylint: disable=protected-access
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
enabled = False # Enable the custom op by setting this to true.
|
24 |
+
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
25 |
+
|
26 |
+
@contextlib.contextmanager
|
27 |
+
def no_weight_gradients():
|
28 |
+
global weight_gradients_disabled
|
29 |
+
old = weight_gradients_disabled
|
30 |
+
weight_gradients_disabled = True
|
31 |
+
yield
|
32 |
+
weight_gradients_disabled = old
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
37 |
+
if _should_use_custom_op(input):
|
38 |
+
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
39 |
+
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
40 |
+
|
41 |
+
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
42 |
+
if _should_use_custom_op(input):
|
43 |
+
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
44 |
+
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
45 |
+
|
46 |
+
#----------------------------------------------------------------------------
|
47 |
+
|
48 |
+
def _should_use_custom_op(input):
|
49 |
+
assert isinstance(input, torch.Tensor)
|
50 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
51 |
+
return False
|
52 |
+
if input.device.type != 'cuda':
|
53 |
+
return False
|
54 |
+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']):
|
55 |
+
return True
|
56 |
+
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
57 |
+
return False
|
58 |
+
|
59 |
+
def _tuple_of_ints(xs, ndim):
|
60 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
61 |
+
assert len(xs) == ndim
|
62 |
+
assert all(isinstance(x, int) for x in xs)
|
63 |
+
return xs
|
64 |
+
|
65 |
+
#----------------------------------------------------------------------------
|
66 |
+
|
67 |
+
_conv2d_gradfix_cache = dict()
|
68 |
+
|
69 |
+
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
70 |
+
# Parse arguments.
|
71 |
+
ndim = 2
|
72 |
+
weight_shape = tuple(weight_shape)
|
73 |
+
stride = _tuple_of_ints(stride, ndim)
|
74 |
+
padding = _tuple_of_ints(padding, ndim)
|
75 |
+
output_padding = _tuple_of_ints(output_padding, ndim)
|
76 |
+
dilation = _tuple_of_ints(dilation, ndim)
|
77 |
+
|
78 |
+
# Lookup from cache.
|
79 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
80 |
+
if key in _conv2d_gradfix_cache:
|
81 |
+
return _conv2d_gradfix_cache[key]
|
82 |
+
|
83 |
+
# Validate arguments.
|
84 |
+
assert groups >= 1
|
85 |
+
assert len(weight_shape) == ndim + 2
|
86 |
+
assert all(stride[i] >= 1 for i in range(ndim))
|
87 |
+
assert all(padding[i] >= 0 for i in range(ndim))
|
88 |
+
assert all(dilation[i] >= 0 for i in range(ndim))
|
89 |
+
if not transpose:
|
90 |
+
assert all(output_padding[i] == 0 for i in range(ndim))
|
91 |
+
else: # transpose
|
92 |
+
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
93 |
+
|
94 |
+
# Helpers.
|
95 |
+
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
96 |
+
def calc_output_padding(input_shape, output_shape):
|
97 |
+
if transpose:
|
98 |
+
return [0, 0]
|
99 |
+
return [
|
100 |
+
input_shape[i + 2]
|
101 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
102 |
+
- (1 - 2 * padding[i])
|
103 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
104 |
+
for i in range(ndim)
|
105 |
+
]
|
106 |
+
|
107 |
+
# Forward & backward.
|
108 |
+
class Conv2d(torch.autograd.Function):
|
109 |
+
@staticmethod
|
110 |
+
@custom_fwd(cast_inputs=torch.float16)
|
111 |
+
def forward(ctx, input, weight, bias):
|
112 |
+
assert weight.shape == weight_shape
|
113 |
+
if not transpose:
|
114 |
+
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
115 |
+
else: # transpose
|
116 |
+
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
117 |
+
ctx.save_for_backward(input, weight)
|
118 |
+
return output
|
119 |
+
|
120 |
+
@staticmethod
|
121 |
+
@custom_bwd
|
122 |
+
def backward(ctx, grad_output):
|
123 |
+
input, weight = ctx.saved_tensors
|
124 |
+
grad_input = None
|
125 |
+
grad_weight = None
|
126 |
+
grad_bias = None
|
127 |
+
|
128 |
+
if ctx.needs_input_grad[0]:
|
129 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
130 |
+
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None)
|
131 |
+
assert grad_input.shape == input.shape
|
132 |
+
|
133 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
134 |
+
grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float())
|
135 |
+
assert grad_weight.shape == weight_shape
|
136 |
+
|
137 |
+
if ctx.needs_input_grad[2]:
|
138 |
+
grad_bias = grad_output.float().sum([0, 2, 3])
|
139 |
+
|
140 |
+
return grad_input, grad_weight, grad_bias
|
141 |
+
|
142 |
+
# Gradient with respect to the weights.
|
143 |
+
class Conv2dGradWeight(torch.autograd.Function):
|
144 |
+
@staticmethod
|
145 |
+
@custom_fwd(cast_inputs=torch.float16)
|
146 |
+
def forward(ctx, grad_output, input):
|
147 |
+
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
148 |
+
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
149 |
+
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
150 |
+
assert grad_weight.shape == weight_shape
|
151 |
+
ctx.save_for_backward(grad_output, input)
|
152 |
+
return grad_weight
|
153 |
+
|
154 |
+
@staticmethod
|
155 |
+
@custom_bwd
|
156 |
+
def backward(ctx, grad2_grad_weight):
|
157 |
+
grad_output, input = ctx.saved_tensors
|
158 |
+
grad2_grad_output = None
|
159 |
+
grad2_input = None
|
160 |
+
|
161 |
+
if ctx.needs_input_grad[0]:
|
162 |
+
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
163 |
+
assert grad2_grad_output.shape == grad_output.shape
|
164 |
+
|
165 |
+
if ctx.needs_input_grad[1]:
|
166 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
167 |
+
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
168 |
+
assert grad2_input.shape == input.shape
|
169 |
+
|
170 |
+
return grad2_grad_output, grad2_input
|
171 |
+
|
172 |
+
_conv2d_gradfix_cache[key] = Conv2d
|
173 |
+
return Conv2d
|
174 |
+
|
175 |
+
#----------------------------------------------------------------------------
|
sg3_torch_utils/ops/conv2d_resample.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""2D convolution with optional up/downsampling."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .. import misc
|
14 |
+
from . import conv2d_gradfix
|
15 |
+
from . import upfirdn2d
|
16 |
+
from .upfirdn2d import _parse_padding
|
17 |
+
from .upfirdn2d import _get_filter_size
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
def _get_weight_shape(w):
|
22 |
+
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
23 |
+
shape = [int(sz) for sz in w.shape]
|
24 |
+
misc.assert_shape(w, shape)
|
25 |
+
return shape
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
30 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
31 |
+
"""
|
32 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
33 |
+
|
34 |
+
# Flip weight if requested.
|
35 |
+
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
36 |
+
w = w.flip([2, 3])
|
37 |
+
|
38 |
+
# Otherwise => execute using conv2d_gradfix.
|
39 |
+
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
40 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
41 |
+
|
42 |
+
#----------------------------------------------------------------------------
|
43 |
+
|
44 |
+
@misc.profiled_function
|
45 |
+
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
46 |
+
r"""2D convolution with optional up/downsampling.
|
47 |
+
|
48 |
+
Padding is performed only once at the beginning, not between the operations.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
x: Input tensor of shape
|
52 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
53 |
+
w: Weight tensor of shape
|
54 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
55 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
56 |
+
calling upfirdn2d.setup_filter(). None = identity (default).
|
57 |
+
up: Integer upsampling factor (default: 1).
|
58 |
+
down: Integer downsampling factor (default: 1).
|
59 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
60 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
61 |
+
(default: 0).
|
62 |
+
groups: Split input channels into N groups (default: 1).
|
63 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
64 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
68 |
+
"""
|
69 |
+
# Validate arguments.
|
70 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
71 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4)
|
72 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
73 |
+
assert isinstance(up, int) and (up >= 1)
|
74 |
+
assert isinstance(down, int) and (down >= 1)
|
75 |
+
assert isinstance(groups, int) and (groups >= 1)
|
76 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
77 |
+
fw, fh = _get_filter_size(f)
|
78 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
79 |
+
|
80 |
+
# Adjust padding to account for up/downsampling.
|
81 |
+
if up > 1:
|
82 |
+
px0 += (fw + up - 1) // 2
|
83 |
+
px1 += (fw - up) // 2
|
84 |
+
py0 += (fh + up - 1) // 2
|
85 |
+
py1 += (fh - up) // 2
|
86 |
+
if down > 1:
|
87 |
+
px0 += (fw - down + 1) // 2
|
88 |
+
px1 += (fw - down) // 2
|
89 |
+
py0 += (fh - down + 1) // 2
|
90 |
+
py1 += (fh - down) // 2
|
91 |
+
|
92 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
93 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
94 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
95 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
96 |
+
return x
|
97 |
+
|
98 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
99 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
100 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
101 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
102 |
+
return x
|
103 |
+
|
104 |
+
# Fast path: downsampling only => use strided convolution.
|
105 |
+
if down > 1 and up == 1:
|
106 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
107 |
+
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
108 |
+
return x
|
109 |
+
|
110 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
111 |
+
if up > 1:
|
112 |
+
if groups == 1:
|
113 |
+
w = w.transpose(0, 1)
|
114 |
+
else:
|
115 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
116 |
+
w = w.transpose(1, 2)
|
117 |
+
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
118 |
+
px0 -= kw - 1
|
119 |
+
px1 -= kw - up
|
120 |
+
py0 -= kh - 1
|
121 |
+
py1 -= kh - up
|
122 |
+
pxt = max(min(-px0, -px1), 0)
|
123 |
+
pyt = max(min(-py0, -py1), 0)
|
124 |
+
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
125 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
126 |
+
if down > 1:
|
127 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
128 |
+
return x
|
129 |
+
|
130 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
131 |
+
if up == 1 and down == 1:
|
132 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
133 |
+
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
134 |
+
|
135 |
+
# Fallback: Generic reference implementation.
|
136 |
+
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
137 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
138 |
+
if down > 1:
|
139 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
140 |
+
return x
|
141 |
+
|
142 |
+
#----------------------------------------------------------------------------
|
sg3_torch_utils/ops/fma.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
13 |
+
|
14 |
+
#----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
def fma(a, b, c): # => a * b + c
|
17 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
22 |
+
@staticmethod
|
23 |
+
@custom_fwd(cast_inputs=torch.float16)
|
24 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
25 |
+
out = torch.addcmul(c, a, b)
|
26 |
+
ctx.save_for_backward(a, b)
|
27 |
+
ctx.c_shape = c.shape
|
28 |
+
return out
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
@custom_bwd
|
32 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
33 |
+
a, b = ctx.saved_tensors
|
34 |
+
c_shape = ctx.c_shape
|
35 |
+
da = None
|
36 |
+
db = None
|
37 |
+
dc = None
|
38 |
+
|
39 |
+
if ctx.needs_input_grad[0]:
|
40 |
+
da = _unbroadcast(dout * b, a.shape)
|
41 |
+
|
42 |
+
if ctx.needs_input_grad[1]:
|
43 |
+
db = _unbroadcast(dout * a, b.shape)
|
44 |
+
|
45 |
+
if ctx.needs_input_grad[2]:
|
46 |
+
dc = _unbroadcast(dout, c_shape)
|
47 |
+
|
48 |
+
return da, db, dc
|
49 |
+
|
50 |
+
#----------------------------------------------------------------------------
|
51 |
+
|
52 |
+
def _unbroadcast(x, shape):
|
53 |
+
extra_dims = x.ndim - len(shape)
|
54 |
+
assert extra_dims >= 0
|
55 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
56 |
+
if len(dim):
|
57 |
+
x = x.sum(dim=dim, keepdim=True)
|
58 |
+
if extra_dims:
|
59 |
+
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
60 |
+
assert x.shape == shape
|
61 |
+
return x
|
62 |
+
|
63 |
+
#----------------------------------------------------------------------------
|
sg3_torch_utils/ops/grid_sample_gradfix.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
10 |
+
supports arbitrarily high order gradients between the input and output.
|
11 |
+
Only works on 2D images and assumes
|
12 |
+
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
16 |
+
from pkg_resources import parse_version
|
17 |
+
# pylint: disable=redefined-builtin
|
18 |
+
# pylint: disable=arguments-differ
|
19 |
+
# pylint: disable=protected-access
|
20 |
+
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
|
21 |
+
|
22 |
+
|
23 |
+
#----------------------------------------------------------------------------
|
24 |
+
|
25 |
+
enabled = False # Enable the custom op by setting this to true.
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
def grid_sample(input, grid):
|
30 |
+
if _should_use_custom_op():
|
31 |
+
return _GridSample2dForward.apply(input, grid)
|
32 |
+
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
def _should_use_custom_op():
|
37 |
+
return enabled
|
38 |
+
|
39 |
+
#----------------------------------------------------------------------------
|
40 |
+
|
41 |
+
class _GridSample2dForward(torch.autograd.Function):
|
42 |
+
@staticmethod
|
43 |
+
@custom_fwd(cast_inputs=torch.float16)
|
44 |
+
def forward(ctx, input, grid):
|
45 |
+
assert input.ndim == 4
|
46 |
+
assert grid.ndim == 4
|
47 |
+
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
48 |
+
ctx.save_for_backward(input, grid)
|
49 |
+
return output
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
@custom_bwd
|
53 |
+
def backward(ctx, grad_output):
|
54 |
+
input, grid = ctx.saved_tensors
|
55 |
+
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
56 |
+
return grad_input, grad_grid
|
57 |
+
|
58 |
+
#----------------------------------------------------------------------------
|
59 |
+
|
60 |
+
class _GridSample2dBackward(torch.autograd.Function):
|
61 |
+
@staticmethod
|
62 |
+
@custom_fwd(cast_inputs=torch.float16)
|
63 |
+
def forward(ctx, grad_output, input, grid):
|
64 |
+
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
65 |
+
if _use_pytorch_1_11_api:
|
66 |
+
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
|
67 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
|
68 |
+
else:
|
69 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
70 |
+
ctx.save_for_backward(grid)
|
71 |
+
return grad_input, grad_grid
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
@custom_bwd
|
75 |
+
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
76 |
+
_ = grad2_grad_grid # unused
|
77 |
+
grid, = ctx.saved_tensors
|
78 |
+
grad2_grad_output = None
|
79 |
+
grad2_input = None
|
80 |
+
grad2_grid = None
|
81 |
+
|
82 |
+
if ctx.needs_input_grad[0]:
|
83 |
+
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
84 |
+
|
85 |
+
assert not ctx.needs_input_grad[2]
|
86 |
+
return grad2_grad_output, grad2_input, grad2_grid
|
87 |
+
|
88 |
+
#----------------------------------------------------------------------------
|
sg3_torch_utils/ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,103 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "upfirdn2d.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
17 |
+
{
|
18 |
+
// Validate arguments.
|
19 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
20 |
+
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
21 |
+
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
22 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
23 |
+
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
24 |
+
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
25 |
+
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
26 |
+
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
27 |
+
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
28 |
+
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
29 |
+
|
30 |
+
// Create output tensor.
|
31 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
32 |
+
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
33 |
+
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
34 |
+
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
35 |
+
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
36 |
+
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
37 |
+
|
38 |
+
// Initialize CUDA kernel parameters.
|
39 |
+
upfirdn2d_kernel_params p;
|
40 |
+
p.x = x.data_ptr();
|
41 |
+
p.f = f.data_ptr<float>();
|
42 |
+
p.y = y.data_ptr();
|
43 |
+
p.up = make_int2(upx, upy);
|
44 |
+
p.down = make_int2(downx, downy);
|
45 |
+
p.pad0 = make_int2(padx0, pady0);
|
46 |
+
p.flip = (flip) ? 1 : 0;
|
47 |
+
p.gain = gain;
|
48 |
+
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
49 |
+
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
50 |
+
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
51 |
+
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
52 |
+
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
53 |
+
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
54 |
+
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
55 |
+
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
56 |
+
|
57 |
+
// Choose CUDA kernel.
|
58 |
+
upfirdn2d_kernel_spec spec;
|
59 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
60 |
+
{
|
61 |
+
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
62 |
+
});
|
63 |
+
|
64 |
+
// Set looping options.
|
65 |
+
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
66 |
+
p.loopMinor = spec.loopMinor;
|
67 |
+
p.loopX = spec.loopX;
|
68 |
+
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
69 |
+
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
70 |
+
|
71 |
+
// Compute grid size.
|
72 |
+
dim3 blockSize, gridSize;
|
73 |
+
if (spec.tileOutW < 0) // large
|
74 |
+
{
|
75 |
+
blockSize = dim3(4, 32, 1);
|
76 |
+
gridSize = dim3(
|
77 |
+
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
78 |
+
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
79 |
+
p.launchMajor);
|
80 |
+
}
|
81 |
+
else // small
|
82 |
+
{
|
83 |
+
blockSize = dim3(256, 1, 1);
|
84 |
+
gridSize = dim3(
|
85 |
+
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
86 |
+
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
87 |
+
p.launchMajor);
|
88 |
+
}
|
89 |
+
|
90 |
+
// Launch CUDA kernel.
|
91 |
+
void* args[] = {&p};
|
92 |
+
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
93 |
+
return y;
|
94 |
+
}
|
95 |
+
|
96 |
+
//------------------------------------------------------------------------
|
97 |
+
|
98 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
99 |
+
{
|
100 |
+
m.def("upfirdn2d", &upfirdn2d);
|
101 |
+
}
|
102 |
+
|
103 |
+
//------------------------------------------------------------------------
|
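The two source files above are compiled lazily by the Python wrapper through `sg3_torch_utils.custom_ops.get_plugin` (see `upfirdn2d.py` further down). A hedged sketch of doing the same thing directly with PyTorch's JIT extension builder, useful when debugging the kernel in isolation; the `ops_dir` path is an assumption about where the checkout lives:

    import os
    import torch.utils.cpp_extension

    ops_dir = "sg3_torch_utils/ops"   # assumed checkout location
    plugin = torch.utils.cpp_extension.load(
        name="upfirdn2d_plugin",
        sources=[os.path.join(ops_dir, "upfirdn2d.cpp"), os.path.join(ops_dir, "upfirdn2d.cu")],
        extra_cuda_cflags=["--use_fast_math"],
    )
    # Exposes the same entry point the wrapper calls:
    # plugin.upfirdn2d(x, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip, gain)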
sg3_torch_utils/ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,350 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "upfirdn2d.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
static __device__ __forceinline__ int floor_div(int a, int b)
|
21 |
+
{
|
22 |
+
int t = 1 - a / b;
|
23 |
+
return (a + t * b) / b - t;
|
24 |
+
}
|
25 |
+
|
26 |
+
//------------------------------------------------------------------------
|
27 |
+
// Generic CUDA implementation for large filters.
|
28 |
+
|
29 |
+
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
30 |
+
{
|
31 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
32 |
+
|
33 |
+
// Calculate thread index.
|
34 |
+
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
35 |
+
int outY = minorBase / p.launchMinor;
|
36 |
+
minorBase -= outY * p.launchMinor;
|
37 |
+
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
38 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
39 |
+
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
40 |
+
return;
|
41 |
+
|
42 |
+
// Setup Y receptive field.
|
43 |
+
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
44 |
+
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
45 |
+
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
46 |
+
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
47 |
+
if (p.flip)
|
48 |
+
filterY = p.filterSize.y - 1 - filterY;
|
49 |
+
|
50 |
+
// Loop over major, minor, and X.
|
51 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
52 |
+
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
53 |
+
{
|
54 |
+
int nc = major * p.sizeMinor + minor;
|
55 |
+
int n = nc / p.inSize.z;
|
56 |
+
int c = nc - n * p.inSize.z;
|
57 |
+
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
58 |
+
{
|
59 |
+
// Setup X receptive field.
|
60 |
+
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
61 |
+
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
62 |
+
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
63 |
+
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
64 |
+
if (p.flip)
|
65 |
+
filterX = p.filterSize.x - 1 - filterX;
|
66 |
+
|
67 |
+
// Initialize pointers.
|
68 |
+
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
69 |
+
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
70 |
+
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
71 |
+
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
72 |
+
|
73 |
+
// Inner loop.
|
74 |
+
scalar_t v = 0;
|
75 |
+
for (int y = 0; y < h; y++)
|
76 |
+
{
|
77 |
+
for (int x = 0; x < w; x++)
|
78 |
+
{
|
79 |
+
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
80 |
+
xp += p.inStride.x;
|
81 |
+
fp += filterStepX;
|
82 |
+
}
|
83 |
+
xp += p.inStride.y - w * p.inStride.x;
|
84 |
+
fp += filterStepY - w * filterStepX;
|
85 |
+
}
|
86 |
+
|
87 |
+
// Store result.
|
88 |
+
v *= p.gain;
|
89 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
//------------------------------------------------------------------------
|
95 |
+
// Specialized CUDA implementation for small filters.
|
96 |
+
|
97 |
+
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
98 |
+
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
99 |
+
{
|
100 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
101 |
+
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
102 |
+
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
103 |
+
__shared__ volatile scalar_t sf[filterH][filterW];
|
104 |
+
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
105 |
+
|
106 |
+
// Calculate tile index.
|
107 |
+
int minorBase = blockIdx.x;
|
108 |
+
int tileOutY = minorBase / p.launchMinor;
|
109 |
+
minorBase -= tileOutY * p.launchMinor;
|
110 |
+
minorBase *= loopMinor;
|
111 |
+
tileOutY *= tileOutH;
|
112 |
+
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
113 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
114 |
+
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
115 |
+
return;
|
116 |
+
|
117 |
+
// Load filter (flipped).
|
118 |
+
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
119 |
+
{
|
120 |
+
int fy = tapIdx / filterW;
|
121 |
+
int fx = tapIdx - fy * filterW;
|
122 |
+
scalar_t v = 0;
|
123 |
+
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
124 |
+
{
|
125 |
+
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
126 |
+
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
127 |
+
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
128 |
+
}
|
129 |
+
sf[fy][fx] = v;
|
130 |
+
}
|
131 |
+
|
132 |
+
// Loop over major and X.
|
133 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
134 |
+
{
|
135 |
+
int baseNC = major * p.sizeMinor + minorBase;
|
136 |
+
int n = baseNC / p.inSize.z;
|
137 |
+
int baseC = baseNC - n * p.inSize.z;
|
138 |
+
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
139 |
+
{
|
140 |
+
// Load input pixels.
|
141 |
+
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
142 |
+
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
143 |
+
int tileInX = floor_div(tileMidX, upx);
|
144 |
+
int tileInY = floor_div(tileMidY, upy);
|
145 |
+
__syncthreads();
|
146 |
+
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
147 |
+
{
|
148 |
+
int relC = inIdx;
|
149 |
+
int relInX = relC / loopMinor;
|
150 |
+
int relInY = relInX / tileInW;
|
151 |
+
relC -= relInX * loopMinor;
|
152 |
+
relInX -= relInY * tileInW;
|
153 |
+
int c = baseC + relC;
|
154 |
+
int inX = tileInX + relInX;
|
155 |
+
int inY = tileInY + relInY;
|
156 |
+
scalar_t v = 0;
|
157 |
+
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
158 |
+
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
159 |
+
sx[relInY][relInX][relC] = v;
|
160 |
+
}
|
161 |
+
|
162 |
+
// Loop over output pixels.
|
163 |
+
__syncthreads();
|
164 |
+
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
165 |
+
{
|
166 |
+
int relC = outIdx;
|
167 |
+
int relOutX = relC / loopMinor;
|
168 |
+
int relOutY = relOutX / tileOutW;
|
169 |
+
relC -= relOutX * loopMinor;
|
170 |
+
relOutX -= relOutY * tileOutW;
|
171 |
+
int c = baseC + relC;
|
172 |
+
int outX = tileOutX + relOutX;
|
173 |
+
int outY = tileOutY + relOutY;
|
174 |
+
|
175 |
+
// Setup receptive field.
|
176 |
+
int midX = tileMidX + relOutX * downx;
|
177 |
+
int midY = tileMidY + relOutY * downy;
|
178 |
+
int inX = floor_div(midX, upx);
|
179 |
+
int inY = floor_div(midY, upy);
|
180 |
+
int relInX = inX - tileInX;
|
181 |
+
int relInY = inY - tileInY;
|
182 |
+
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
183 |
+
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
184 |
+
|
185 |
+
// Inner loop.
|
186 |
+
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
187 |
+
{
|
188 |
+
scalar_t v = 0;
|
189 |
+
#pragma unroll
|
190 |
+
for (int y = 0; y < filterH / upy; y++)
|
191 |
+
#pragma unroll
|
192 |
+
for (int x = 0; x < filterW / upx; x++)
|
193 |
+
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
194 |
+
v *= p.gain;
|
195 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
196 |
+
}
|
197 |
+
}
|
198 |
+
}
|
199 |
+
}
|
200 |
+
}
|
201 |
+
|
202 |
+
//------------------------------------------------------------------------
|
203 |
+
// CUDA kernel selection.
|
204 |
+
|
205 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
206 |
+
{
|
207 |
+
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
208 |
+
|
209 |
+
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
210 |
+
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
211 |
+
|
212 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
213 |
+
{
|
214 |
+
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
215 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
216 |
+
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
217 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
218 |
+
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
219 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
220 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
221 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
222 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
223 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
224 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
225 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
226 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
227 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
228 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
229 |
+
}
|
230 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
231 |
+
{
|
232 |
+
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
233 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
234 |
+
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
235 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
236 |
+
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
237 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
238 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
239 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
240 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
241 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
242 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
243 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
244 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
245 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
246 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
247 |
+
}
|
248 |
+
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
249 |
+
{
|
250 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
251 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
252 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
253 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
254 |
+
}
|
255 |
+
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
256 |
+
{
|
257 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
258 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
259 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
260 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
261 |
+
}
|
262 |
+
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
263 |
+
{
|
264 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
265 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
266 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
267 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
268 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
269 |
+
}
|
270 |
+
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
271 |
+
{
|
272 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
273 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
274 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
275 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
276 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
277 |
+
}
|
278 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
279 |
+
{
|
280 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
281 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
282 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
283 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
284 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
285 |
+
}
|
286 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
287 |
+
{
|
288 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
289 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
290 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
291 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
292 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
293 |
+
}
|
294 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
|
295 |
+
{
|
296 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
297 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
298 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
299 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
300 |
+
}
|
301 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
|
302 |
+
{
|
303 |
+
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
304 |
+
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
305 |
+
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
306 |
+
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
307 |
+
}
|
308 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
|
309 |
+
{
|
310 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
311 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
|
312 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
313 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
|
314 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
315 |
+
}
|
316 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
|
317 |
+
{
|
318 |
+
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
319 |
+
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
|
320 |
+
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
321 |
+
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
|
322 |
+
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
323 |
+
}
|
324 |
+
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
|
325 |
+
{
|
326 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
327 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
|
328 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
329 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
|
330 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
331 |
+
}
|
332 |
+
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
|
333 |
+
{
|
334 |
+
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
335 |
+
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
|
336 |
+
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
337 |
+
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
|
338 |
+
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
339 |
+
}
|
340 |
+
return spec;
|
341 |
+
}
|
342 |
+
|
343 |
+
//------------------------------------------------------------------------
|
344 |
+
// Template specializations.
|
345 |
+
|
346 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
347 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
348 |
+
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
349 |
+
|
350 |
+
//------------------------------------------------------------------------
|
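One detail worth spelling out about the kernels above: `floor_div` exists because C++ integer division truncates toward zero, while the receptive-field arithmetic needs flooring toward negative infinity when a padded coordinate goes negative. A throwaway Python check of the same formula (a sketch; `int()` stands in for C's truncating division):

    def floor_div_c_style(a: int, b: int) -> int:
        # Direct transcription of the CUDA helper; int() emulates C's truncation toward zero.
        t = 1 - int(a / b)
        return int((a + t * b) / b) - t

    for a in range(-12, 13):
        assert floor_div_c_style(a, 3) == a // 3   # Python's // already floors toward -inf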
sg3_torch_utils/ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
|
11 |
+
//------------------------------------------------------------------------
|
12 |
+
// CUDA kernel parameters.
|
13 |
+
|
14 |
+
struct upfirdn2d_kernel_params
|
15 |
+
{
|
16 |
+
const void* x;
|
17 |
+
const float* f;
|
18 |
+
void* y;
|
19 |
+
|
20 |
+
int2 up;
|
21 |
+
int2 down;
|
22 |
+
int2 pad0;
|
23 |
+
int flip;
|
24 |
+
float gain;
|
25 |
+
|
26 |
+
int4 inSize; // [width, height, channel, batch]
|
27 |
+
int4 inStride;
|
28 |
+
int2 filterSize; // [width, height]
|
29 |
+
int2 filterStride;
|
30 |
+
int4 outSize; // [width, height, channel, batch]
|
31 |
+
int4 outStride;
|
32 |
+
int sizeMinor;
|
33 |
+
int sizeMajor;
|
34 |
+
|
35 |
+
int loopMinor;
|
36 |
+
int loopMajor;
|
37 |
+
int loopX;
|
38 |
+
int launchMinor;
|
39 |
+
int launchMajor;
|
40 |
+
};
|
41 |
+
|
42 |
+
//------------------------------------------------------------------------
|
43 |
+
// CUDA kernel specialization.
|
44 |
+
|
45 |
+
struct upfirdn2d_kernel_spec
|
46 |
+
{
|
47 |
+
void* kernel;
|
48 |
+
int tileOutW;
|
49 |
+
int tileOutH;
|
50 |
+
int loopMinor;
|
51 |
+
int loopX;
|
52 |
+
};
|
53 |
+
|
54 |
+
//------------------------------------------------------------------------
|
55 |
+
// CUDA kernel selection.
|
56 |
+
|
57 |
+
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
58 |
+
|
59 |
+
//------------------------------------------------------------------------
|
sg3_torch_utils/ops/upfirdn2d.py
ADDED
@@ -0,0 +1,388 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import traceback
|
16 |
+
|
17 |
+
from .. import custom_ops
|
18 |
+
from .. import misc
|
19 |
+
from . import conv2d_gradfix
|
20 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
21 |
+
|
22 |
+
#----------------------------------------------------------------------------
|
23 |
+
|
24 |
+
_inited = False
|
25 |
+
_plugin = None
|
26 |
+
enabled = False
|
27 |
+
|
28 |
+
def _init():
|
29 |
+
global _inited, _plugin
|
30 |
+
if not _inited:
|
31 |
+
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
|
32 |
+
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
33 |
+
try:
|
34 |
+
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
35 |
+
except:
|
36 |
+
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
37 |
+
return _plugin is not None
|
38 |
+
|
39 |
+
def _parse_scaling(scaling):
|
40 |
+
if isinstance(scaling, int):
|
41 |
+
scaling = [scaling, scaling]
|
42 |
+
assert isinstance(scaling, (list, tuple))
|
43 |
+
assert all(isinstance(x, int) for x in scaling)
|
44 |
+
sx, sy = scaling
|
45 |
+
assert sx >= 1 and sy >= 1
|
46 |
+
return sx, sy
|
47 |
+
|
48 |
+
def _parse_padding(padding):
|
49 |
+
if isinstance(padding, int):
|
50 |
+
padding = [padding, padding]
|
51 |
+
assert isinstance(padding, (list, tuple))
|
52 |
+
assert all(isinstance(x, int) for x in padding)
|
53 |
+
if len(padding) == 2:
|
54 |
+
padx, pady = padding
|
55 |
+
padding = [padx, padx, pady, pady]
|
56 |
+
padx0, padx1, pady0, pady1 = padding
|
57 |
+
return padx0, padx1, pady0, pady1
|
58 |
+
|
59 |
+
def _get_filter_size(f):
|
60 |
+
if f is None:
|
61 |
+
return 1, 1
|
62 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
63 |
+
fw = f.shape[-1]
|
64 |
+
fh = f.shape[0]
|
65 |
+
with misc.suppress_tracer_warnings():
|
66 |
+
fw = int(fw)
|
67 |
+
fh = int(fh)
|
68 |
+
misc.assert_shape(f, [fh, fw][:f.ndim])
|
69 |
+
assert fw >= 1 and fh >= 1
|
70 |
+
return fw, fh
|
71 |
+
|
72 |
+
#----------------------------------------------------------------------------
|
73 |
+
|
74 |
+
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
75 |
+
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
f: Torch tensor, numpy array, or python list of the shape
|
79 |
+
`[filter_height, filter_width]` (non-separable),
|
80 |
+
`[filter_taps]` (separable),
|
81 |
+
`[]` (impulse), or
|
82 |
+
`None` (identity).
|
83 |
+
device: Result device (default: cpu).
|
84 |
+
normalize: Normalize the filter so that it retains the magnitude
|
85 |
+
for constant input signal (DC)? (default: True).
|
86 |
+
flip_filter: Flip the filter? (default: False).
|
87 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
88 |
+
separable: Return a separable filter? (default: select automatically).
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
Float32 tensor of the shape
|
92 |
+
`[filter_height, filter_width]` (non-separable) or
|
93 |
+
`[filter_taps]` (separable).
|
94 |
+
"""
|
95 |
+
# Validate.
|
96 |
+
if f is None:
|
97 |
+
f = 1
|
98 |
+
f = torch.as_tensor(f, dtype=torch.float32)
|
99 |
+
assert f.ndim in [0, 1, 2]
|
100 |
+
assert f.numel() > 0
|
101 |
+
if f.ndim == 0:
|
102 |
+
f = f[np.newaxis]
|
103 |
+
|
104 |
+
# Separable?
|
105 |
+
if separable is None:
|
106 |
+
separable = (f.ndim == 1 and f.numel() >= 8)
|
107 |
+
if f.ndim == 1 and not separable:
|
108 |
+
f = f.ger(f)
|
109 |
+
assert f.ndim == (1 if separable else 2)
|
110 |
+
|
111 |
+
# Apply normalize, flip, gain, and device.
|
112 |
+
if normalize:
|
113 |
+
f /= f.sum()
|
114 |
+
if flip_filter:
|
115 |
+
f = f.flip(list(range(f.ndim)))
|
116 |
+
f = f * (gain ** (f.ndim / 2))
|
117 |
+
f = f.to(device=device)
|
118 |
+
return f
|
119 |
+
|
120 |
+
#----------------------------------------------------------------------------
|
121 |
+
|
122 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
123 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
124 |
+
|
125 |
+
Performs the following sequence of operations for each channel:
|
126 |
+
|
127 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
128 |
+
|
129 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
130 |
+
Negative padding corresponds to cropping the image.
|
131 |
+
|
132 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
133 |
+
so that the footprint of all output pixels lies within the input image.
|
134 |
+
|
135 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
136 |
+
|
137 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
138 |
+
The fused op is considerably more efficient than performing the same calculation
|
139 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x: Float32/float64/float16 input tensor of the shape
|
143 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
144 |
+
f: Float32 FIR filter of the shape
|
145 |
+
`[filter_height, filter_width]` (non-separable),
|
146 |
+
`[filter_taps]` (separable), or
|
147 |
+
`None` (identity).
|
148 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
149 |
+
`[x, y]` (default: 1).
|
150 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
151 |
+
`[x, y]` (default: 1).
|
152 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
153 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
154 |
+
(default: 0).
|
155 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
156 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
157 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
161 |
+
"""
|
162 |
+
assert isinstance(x, torch.Tensor)
|
163 |
+
assert impl in ['ref', 'cuda']
|
164 |
+
if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
|
165 |
+
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
166 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
167 |
+
|
168 |
+
#----------------------------------------------------------------------------
|
169 |
+
|
170 |
+
@misc.profiled_function
|
171 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
172 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
173 |
+
"""
|
174 |
+
# Validate arguments.
|
175 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
176 |
+
if f is None:
|
177 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
178 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
179 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
180 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
181 |
+
upx, upy = _parse_scaling(up)
|
182 |
+
downx, downy = _parse_scaling(down)
|
183 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
184 |
+
|
185 |
+
# Upsample by inserting zeros.
|
186 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
187 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
188 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
189 |
+
|
190 |
+
# Pad or crop.
|
191 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
192 |
+
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
193 |
+
|
194 |
+
# Setup filter.
|
195 |
+
f = f * (gain ** (f.ndim / 2))
|
196 |
+
f = f.to(x.dtype)
|
197 |
+
if not flip_filter:
|
198 |
+
f = f.flip(list(range(f.ndim)))
|
199 |
+
|
200 |
+
# Convolve with the filter.
|
201 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
202 |
+
if f.ndim == 4:
|
203 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
204 |
+
else:
|
205 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
206 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
207 |
+
|
208 |
+
# Downsample by throwing away pixels.
|
209 |
+
x = x[:, :, ::downy, ::downx]
|
210 |
+
return x
|
211 |
+
|
212 |
+
#----------------------------------------------------------------------------
|
213 |
+
|
214 |
+
_upfirdn2d_cuda_cache = dict()
|
215 |
+
|
216 |
+
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
217 |
+
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
218 |
+
"""
|
219 |
+
# Parse arguments.
|
220 |
+
upx, upy = _parse_scaling(up)
|
221 |
+
downx, downy = _parse_scaling(down)
|
222 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
223 |
+
|
224 |
+
# Lookup from cache.
|
225 |
+
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
226 |
+
if key in _upfirdn2d_cuda_cache:
|
227 |
+
return _upfirdn2d_cuda_cache[key]
|
228 |
+
|
229 |
+
# Forward op.
|
230 |
+
class Upfirdn2dCuda(torch.autograd.Function):
|
231 |
+
@staticmethod
|
232 |
+
@custom_fwd(cast_inputs=torch.float32)
|
233 |
+
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
234 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
235 |
+
if f is None:
|
236 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
237 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
238 |
+
y = x
|
239 |
+
if f.ndim == 2:
|
240 |
+
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
241 |
+
else:
|
242 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
|
243 |
+
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
|
244 |
+
ctx.save_for_backward(f)
|
245 |
+
ctx.x_shape = x.shape
|
246 |
+
return y
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
@custom_bwd
|
250 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
251 |
+
f, = ctx.saved_tensors
|
252 |
+
_, _, ih, iw = ctx.x_shape
|
253 |
+
_, _, oh, ow = dy.shape
|
254 |
+
fw, fh = _get_filter_size(f)
|
255 |
+
p = [
|
256 |
+
fw - padx0 - 1,
|
257 |
+
iw * upx - ow * downx + padx0 - upx + 1,
|
258 |
+
fh - pady0 - 1,
|
259 |
+
ih * upy - oh * downy + pady0 - upy + 1,
|
260 |
+
]
|
261 |
+
dx = None
|
262 |
+
df = None
|
263 |
+
|
264 |
+
if ctx.needs_input_grad[0]:
|
265 |
+
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
266 |
+
|
267 |
+
assert not ctx.needs_input_grad[1]
|
268 |
+
return dx, df
|
269 |
+
|
270 |
+
# Add to cache.
|
271 |
+
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
272 |
+
return Upfirdn2dCuda
|
273 |
+
|
274 |
+
#----------------------------------------------------------------------------
|
275 |
+
|
276 |
+
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
277 |
+
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
278 |
+
|
279 |
+
By default, the result is padded so that its shape matches the input.
|
280 |
+
User-specified padding is applied on top of that, with negative values
|
281 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
x: Float32/float64/float16 input tensor of the shape
|
285 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
286 |
+
f: Float32 FIR filter of the shape
|
287 |
+
`[filter_height, filter_width]` (non-separable),
|
288 |
+
`[filter_taps]` (separable), or
|
289 |
+
`None` (identity).
|
290 |
+
padding: Padding with respect to the output. Can be a single number or a
|
291 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
292 |
+
(default: 0).
|
293 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
294 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
295 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
299 |
+
"""
|
300 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
301 |
+
fw, fh = _get_filter_size(f)
|
302 |
+
p = [
|
303 |
+
padx0 + fw // 2,
|
304 |
+
padx1 + (fw - 1) // 2,
|
305 |
+
pady0 + fh // 2,
|
306 |
+
pady1 + (fh - 1) // 2,
|
307 |
+
]
|
308 |
+
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
309 |
+
|
310 |
+
#----------------------------------------------------------------------------
|
311 |
+
|
312 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
313 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
314 |
+
|
315 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
316 |
+
User-specified padding is applied on top of that, with negative values
|
317 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
x: Float32/float64/float16 input tensor of the shape
|
321 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
322 |
+
f: Float32 FIR filter of the shape
|
323 |
+
`[filter_height, filter_width]` (non-separable),
|
324 |
+
`[filter_taps]` (separable), or
|
325 |
+
`None` (identity).
|
326 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
327 |
+
`[x, y]` (default: 1).
|
328 |
+
padding: Padding with respect to the output. Can be a single number or a
|
329 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
330 |
+
(default: 0).
|
331 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
332 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
333 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
337 |
+
"""
|
338 |
+
upx, upy = _parse_scaling(up)
|
339 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
340 |
+
fw, fh = _get_filter_size(f)
|
341 |
+
p = [
|
342 |
+
padx0 + (fw + upx - 1) // 2,
|
343 |
+
padx1 + (fw - upx) // 2,
|
344 |
+
pady0 + (fh + upy - 1) // 2,
|
345 |
+
pady1 + (fh - upy) // 2,
|
346 |
+
]
|
347 |
+
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
348 |
+
|
349 |
+
#----------------------------------------------------------------------------
|
350 |
+
|
351 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
352 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
353 |
+
|
354 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
355 |
+
User-specified padding is applied on top of that, with negative values
|
356 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
x: Float32/float64/float16 input tensor of the shape
|
360 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
361 |
+
f: Float32 FIR filter of the shape
|
362 |
+
`[filter_height, filter_width]` (non-separable),
|
363 |
+
`[filter_taps]` (separable), or
|
364 |
+
`None` (identity).
|
365 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
366 |
+
`[x, y]` (default: 1).
|
367 |
+
padding: Padding with respect to the input. Can be a single number or a
|
368 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
369 |
+
(default: 0).
|
370 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
371 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
372 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
376 |
+
"""
|
377 |
+
downx, downy = _parse_scaling(down)
|
378 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
379 |
+
fw, fh = _get_filter_size(f)
|
380 |
+
p = [
|
381 |
+
padx0 + (fw - downx + 1) // 2,
|
382 |
+
padx1 + (fw - downx) // 2,
|
383 |
+
pady0 + (fh - downy + 1) // 2,
|
384 |
+
pady1 + (fh - downy) // 2,
|
385 |
+
]
|
386 |
+
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
387 |
+
|
388 |
+
#----------------------------------------------------------------------------
|
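A short end-to-end sketch of the wrapper API above, assuming this repo's package layout: build a small FIR filter with `setup_filter` and run the pure-PyTorch reference path (`impl='ref'`), which works on CPU and needs no compiled plugin. The shapes follow the output-size formula in `upfirdn2d.cpp`, out = (in*up + pad0 + pad1 - filter + down) // down.

    import torch
    from sg3_torch_utils.ops import upfirdn2d

    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # 4-tap filter; outer product -> 4x4, normalized to DC gain 1
    x = torch.randn(2, 3, 16, 16)

    y_up = upfirdn2d.upsample2d(x, f, up=2, impl='ref')          # -> [2, 3, 32, 32]
    y_dn = upfirdn2d.downsample2d(y_up, f, down=2, impl='ref')   # -> [2, 3, 16, 16]
    assert y_up.shape == (2, 3, 32, 32) and y_dn.shape == x.shape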
stylemc.py
ADDED
@@ -0,0 +1,295 @@
1 |
+
"""
|
2 |
+
Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation"
|
3 |
+
Original source code:
|
4 |
+
https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py#
|
5 |
+
Modified by Håkon Hukkelås
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
import tqdm
|
10 |
+
import re
|
11 |
+
import click
|
12 |
+
from dp2 import utils
|
13 |
+
import tops
|
14 |
+
from typing import List, Optional
|
15 |
+
import PIL.Image
|
16 |
+
import imageio
|
17 |
+
from timeit import default_timer as timer
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torchvision.transforms.functional import resize, normalize
|
24 |
+
from dp2.infer import build_trained_generator
|
25 |
+
import clip
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
class AverageMeter(object):
|
30 |
+
"""Computes and stores the average and current value"""
|
31 |
+
def __init__(self, name, fmt=':f'):
|
32 |
+
self.name = name
|
33 |
+
self.fmt = fmt
|
34 |
+
self.reset()
|
35 |
+
|
36 |
+
def reset(self):
|
37 |
+
self.val = 0
|
38 |
+
self.avg = 0
|
39 |
+
self.sum = 0
|
40 |
+
self.count = 0
|
41 |
+
|
42 |
+
def update(self, val, n=1):
|
43 |
+
self.val = val
|
44 |
+
self.sum += val * n
|
45 |
+
self.count += n
|
46 |
+
self.avg = self.sum / self.count
|
47 |
+
|
48 |
+
def __str__(self):
|
49 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
50 |
+
return fmtstr.format(**self.__dict__)
|
51 |
+
|
52 |
+
|
53 |
+
class ProgressMeter(object):
|
54 |
+
def __init__(self, num_batches, meters, prefix=""):
|
55 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
56 |
+
self.meters = meters
|
57 |
+
self.prefix = prefix
|
58 |
+
|
59 |
+
def display(self, batch):
|
60 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
61 |
+
entries += [str(meter) for meter in self.meters]
|
62 |
+
print('\t'.join(entries))
|
63 |
+
|
64 |
+
def _get_batch_fmtstr(self, num_batches):
|
65 |
+
num_digits = len(str(num_batches // 1))
|
66 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
67 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
68 |
+
|
69 |
+
|
70 |
+
def save_image(img, path):
|
71 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
72 |
+
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)
|
73 |
+
|
74 |
+
|
75 |
+
def unravel_index(index, shape):
|
76 |
+
out = []
|
77 |
+
for dim in reversed(shape):
|
78 |
+
out.append(index % dim)
|
79 |
+
index = index // dim
|
80 |
+
return tuple(reversed(out))
|
81 |
+
|
82 |
+
|
83 |
+
def num_range(s: str) -> List[int]:
|
84 |
+
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
|
85 |
+
|
86 |
+
range_re = re.compile(r'^(\d+)-(\d+)$')
|
87 |
+
m = range_re.match(s)
|
88 |
+
if m:
|
89 |
+
return list(range(int(m.group(1)), int(m.group(2))+1))
|
90 |
+
vals = s.split(',')
|
91 |
+
return [int(x) for x in vals]
|
92 |
+
|
93 |
+
|
94 |
+
#----------------------------------------------------------------------------
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
def spherical_dist_loss(x, y):
|
99 |
+
x = F.normalize(x, dim=-1)
|
100 |
+
y = F.normalize(y, dim=-1)
|
101 |
+
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
102 |
+
|
103 |
+
|
104 |
+
def prompts_dist_loss(x, targets, loss):
|
105 |
+
if len(targets) == 1: # Keeps consistent results vs previous method for single objective guidance
|
106 |
+
return loss(x, targets[0])
|
107 |
+
distances = [loss(x, target) for target in targets]
|
108 |
+
return torch.stack(distances, dim=-1).sum(dim=-1)
|
109 |
+
|
110 |
+
|
111 |
+
def embed_text(model, prompt, device='cuda'):
|
112 |
+
return
|
113 |
+
|
114 |
+
|
115 |
+
#----------------------------------------------------------------------------
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
@torch.cuda.amp.autocast()
|
119 |
+
def generate_edit(
|
120 |
+
G,
|
121 |
+
dl,
|
122 |
+
direction,
|
123 |
+
edit_strength,
|
124 |
+
path,
|
125 |
+
):
|
126 |
+
for it, batch in enumerate(dl):
|
127 |
+
batch["embedding"] = None
|
128 |
+
styles = get_styles(None, G, batch, truncation_value=0)
|
129 |
+
imgs = []
|
130 |
+
grad_changes = [_*edit_strength for _ in [0, 0.25, 0.5, 0.75, 1]]
|
131 |
+
grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
|
132 |
+
batch = {k: tops.to_cuda(v) if v is not None else v for k,v in batch.items()}
|
133 |
+
for i, grad_change in enumerate(grad_changes):
|
134 |
+
s = styles + direction*grad_change
|
135 |
+
|
136 |
+
img = G(**batch, s=iter(s))["img"]
|
137 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
|
138 |
+
imgs.append(img[0].to(torch.uint8).cpu().numpy())
|
139 |
+
PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')


@torch.no_grad()
def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
    all_styles = []
    if seed is None:
        # Zero-scale normal, i.e. a deterministic all-zero latent.
        z = np.random.normal(0, 0, size=(1, G.z_channels))
    else:
        z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
    z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
    w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
    w = G.style_net(torch.from_numpy(z).to(tops.get_device()))

    w = w_c.to(w.dtype).lerp(w, truncation_value)
    if hasattr(G, "get_comod_y"):
        w = G.get_comod_y(batch, w)
    for block in G.modules():
        if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
            continue
        gamma0 = block.affine(w)
        if hasattr(block, "affine_beta"):
            beta0 = block.affine_beta(w)
            gamma0 = torch.cat((gamma0, beta0), dim=1)
        all_styles.append(gamma0)
    max_ch = max([s.shape[-1] for s in all_styles])
    all_styles = [F.pad(s, ((0, max_ch - s.shape[-1])), "constant", 0) for s in all_styles]
    all_styles = torch.cat(all_styles)
    return all_styles
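
# Illustrative note (not part of the original file): get_styles collects one affine
# style vector per synthesis block (gamma, optionally concatenated with beta),
# zero-pads each to the widest channel count and stacks them, so the returned
# tensor (and hence the optimized `direction`) has shape [n_blocks, max_channels].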

def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
    cache_path = output_dir.joinpath(
        "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
    if cache_path.is_file():
        print("Loaded cache from:", cache_path)
        return torch.load(cache_path)
    direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
    cache_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(direction, cache_path)
    return direction
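
# Illustrative usage (not part of the original file; the prompt and directory are
# placeholders):
#   direction = get_and_cache_direction(Path("outputs"), dl_val, G, "a person with a beard")
#   # caches to outputs/stylemc_cache/a_person_with_a_beard.torch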


@torch.cuda.amp.autocast()
def find_direction(
    G,
    text_prompt,
    batches,
    #layers,
    n_iterations=128*8,
    batch_size=8,
    dl_val=None
):
    time_start = timer()

    clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]

    target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
    all_styles = []
    if dl_val is not None:
        first_batch = next(dl_val)
    else:
        first_batch = batches[0]
    first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
    s = get_styles(0, G, first_batch)
    # stats tracker
    cos_sim_track = AverageMeter('cos_sim', ':.4f')
    norm_track = AverageMeter('norm', ':.4f')
    n_iterations = n_iterations // batch_size
    progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])

    # initialize styles direction
    direction = torch.zeros(s.shape, device=tops.get_device())
    direction.requires_grad_()
    utils.set_requires_grad(G, False)
    direction_tracker = torch.zeros_like(direction)
    opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)

    grads = []
    for seed_idx in tqdm.trange(n_iterations):
        # forward pass through synthesis network with new styles
        if seed_idx == 0:
            batch = first_batch
        elif dl_val is not None:
            batch = next(dl_val)
            batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
        else:
            batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
        styles = get_styles(seed_idx, G, batch) + direction
        img = G(**batch, s=iter(styles))["img"]
        batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}
        # clip loss
        img = (img + 1) / 2
        img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        img = resize(img, (224, 224))
        embeds = clip_model.encode_image(img)
        cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
        cos_sim.backward(retain_graph=True)

        # track stats
        cos_sim_track.update(cos_sim.item())
        norm_track.update(torch.norm(direction).item())

        if not (seed_idx % batch_size):

            # zeroing out gradients for non-optimized layers
            #layers_zeroed = torch.tensor([x for x in range(G.num_ws) if not x in layers])
            #direction.grad[:, layers_zeroed] = 0

            opt.step()
            grads.append(direction.grad.clone())
            direction.grad.data.zero_()

            # keep track of gradients over time
            if seed_idx > 3:
                direction_tracker[grads[-2] * grads[-1] < 0] += 1

            # plot stats
            progress.display(seed_idx)

    # throw out fluctuating channels
    direction = direction.detach()
    direction[direction_tracker > n_iterations / 4] = 0
    print(direction)
    print(f"Time for direction search: {timer() - time_start:.2f} s")
    return direction
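
# Illustrative summary (not part of the original file): find_direction optimizes a
# single global style-space offset with AdamW so that CLIP embeddings of the edited
# images move towards the text prompt (spherical distance loss), accumulating
# gradients over `batch_size` samples per optimizer step; channels whose gradient
# sign flips in more than a quarter of the steps are zeroed out as unstable before
# the direction is returned.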




@click.command()
@click.argument("config_path")
@click.argument("input_path")
@click.argument("output_path")
#@click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not to optimize the critically sampled layers (last 3).', required=True)
@click.option('--text-prompt', help='Text', type=str, required=True)
@click.option('--edit-strength', help='Strength of edit', type=float, required=True)
@click.option('--outdir', help='Where to save the output images', type=str, required=True)
def stylemc(
    config_path,
    input_path,
    output_path,
    #layers: List[int],
    text_prompt: str,
    edit_strength: float,
    outdir: str,
):
    cfg = utils.load_config(config_path)
    G = build_trained_generator(cfg)
    cfg.train.batch_size = 1
    n_iterations = 256
    dl_val = tops.config.instantiate(cfg.data.val.loader)

    direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))

    text_prompt = text_prompt.replace(" ", "_")
    # NOTE: generate_edit iterates over batches, so dl_val (rather than input_path)
    # may be the intended second argument here.
    generate_edit(G, input_path, direction, edit_strength, output_path)


if __name__ == "__main__":
    stylemc()
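
# Illustrative CLI usage (not part of the original file; the config and paths are
# placeholders):
#   python stylemc.py configs/fdh/styleganL.py INPUT_PATH OUTPUT_PATH \
#       --text-prompt "a person with a beard" --edit-strength 4 --outdir outputs/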