File size: 5,408 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md

# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# Modified from https://github.com/mit-han-lab/data-efficient-gans
import torch
import torch.nn.functional as F


def apply_diff_aug(data, keys, aug_policy, inplace=False, **kwargs):
    r"""Applies differentiable augmentation.
    Args:
        data (dict): Input data.
        keys (list of str): Keys to the data values that we want to apply
            differentiable augmentation to.
        aug_policy (str): Type of augmentation(s), ``'color'``,
            ``'translation'``, or ``'cutout'`` separated by ``','``.
        inplace (bool): If ``True``, write augmented values back into ``data``
            instead of building a new dict.
    """
    # An empty policy string means "do nothing": hand back the input untouched.
    if aug_policy == '':
        return data
    output = data if inplace else {}
    for name in data:
        # Only the selected keys get augmented; everything else is carried over.
        if name in keys:
            output[name] = diff_aug(data[name], aug_policy, **kwargs)
        else:
            output[name] = data[name]
    return output


def diff_aug(x, policy='', channels_first=True, **kwargs):
    r"""Applies the augmentations named in ``policy`` to a batch of images.

    Args:
        x (tensor): Batch of images, NCHW if ``channels_first`` else NHWC.
        policy (str): Comma-separated augmentation names; keys of
            ``AUGMENT_FNS``. Empty string is a no-op.
        channels_first (bool): Layout of ``x``; internally NCHW is used.
    Returns:
        (tensor): Augmented batch in the same layout as the input.
    """
    # Empty policy: return the input untouched (no permute, no copy).
    if not policy:
        return x
    if not channels_first:
        # NHWC -> NCHW for the augmentation functions.
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for aug_fn in AUGMENT_FNS[name]:
            x = aug_fn(x, **kwargs)
    if not channels_first:
        # Back to the caller's NHWC layout.
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x, **kwargs):
    r"""Adds a per-sample random brightness offset drawn from [-0.5, 0.5)."""
    # One scalar per sample, broadcast over channels and spatial dims.
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
                        device=x.device)
    return x + (offset - 0.5)


def rand_saturation(x, **kwargs):
    r"""Rescales each sample's deviation from its channel mean by a random
    factor drawn from [0, 2), i.e. a random saturation change."""
    mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
                        device=x.device) * 2
    return (x - mean) * factor + mean


def rand_contrast(x, **kwargs):
    r"""Rescales each sample's deviation from its global mean by a random
    factor drawn from [0.5, 1.5), i.e. a random contrast change."""
    mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype,
                        device=x.device) + 0.5
    return (x - mean) * factor + mean


def rand_translation(x, ratio=0.125, **kwargs):
    r"""Randomly translates each sample by up to ``ratio`` of its height/width.

    Out-of-frame regions are filled with zeros (via the 1-pixel zero pad and
    clamped indices below).

    Args:
        x (tensor): NCHW batch of images.
        ratio (float): Maximum shift as a fraction of the spatial size.
    Returns:
        (tensor): Translated batch, same shape as ``x``.
    """
    # Maximum shift in pixels along each spatial dim (rounded to nearest int).
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(
        x.size(3) * ratio + 0.5)
    # Per-sample integer shifts in [-shift, shift], broadcastable over H x W.
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1],
                                  device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1],
                                  device=x.device)
    # Index grids over (batch, H, W) used to gather shifted pixels.
    # noinspection PyTypeChecker
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # The +1 accounts for the 1-pixel zero padding added below; clamping maps
    # out-of-range source positions onto the zero border, producing zero fill.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # Zero-pad one pixel on each spatial side (pad order: W, W, H, H, C, C...).
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in NHWC so advanced indexing picks whole pixels, then restore NCHW.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[
        grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5, **kwargs):
    r"""Zeroes out one random rectangular region per sample (CutOut).

    Args:
        x (tensor): NCHW batch of images.
        ratio (float): Side length of the cutout square as a fraction of the
            spatial size.
    Returns:
        (tensor): Batch with a random ``cutout_size`` region zeroed per sample.
    """
    # Cutout rectangle size in pixels (rounded to nearest int).
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Random per-sample center offsets; the parity term lets the (clamped)
    # cutout reach both image borders.
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2),
                             size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2),
                             size=[x.size(0), 1, 1], device=x.device)
    # Index grids over (batch, cutout_h, cutout_w) addressing the cut region.
    # noinspection PyTypeChecker
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # Center the rectangle on the offset and clamp it inside the image, so
    # cutouts near a border are truncated rather than wrapped.
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0,
                         max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0,
                         max=x.size(3) - 1)
    # Build a per-sample spatial mask: 1 everywhere, 0 inside the cut region.
    mask = torch.ones(x.size(0), x.size(2), x.size(3),
                      dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the mask over channels.
    x = x * mask.unsqueeze(1)
    return x


def rand_translation_scale(x, trans_r=0.125, scale_r=0.125,
                           mode='bilinear', padding_mode='reflection',
                           **kwargs):
    r"""Randomly translates and scales each sample via an affine warp.

    Args:
        x (tensor): NCHW batch of images.
        trans_r (float): Translations are sampled uniformly from
            ``(-trans_r, trans_r)`` (in normalized [-1, 1] coordinates).
        scale_r (float): Per-axis scale factors are sampled uniformly from
            ``(1 - scale_r, 1 + scale_r)``.
        mode (str): Interpolation mode for ``grid_sample``.
        padding_mode (str): Out-of-frame fill mode for ``grid_sample``.
    Returns:
        (tensor): Warped batch (float32), same shape as ``x``.
    """
    assert x.dim() == 4, "Input must be a 4D tensor."
    batch_size = x.size(0)

    # Identity transformation.
    theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat(
        batch_size, 1, 1)

    # Translation, uniformly sampled from (-trans_r, trans_r).
    translate = \
        2 * trans_r * torch.rand(batch_size, 2, device=x.device) - trans_r
    theta[:, :, 2] += translate

    # Scaling, uniformly sampled from (1-scale_r, 1+scale_r).
    scale = \
        2 * scale_r * torch.rand(batch_size, 2, device=x.device) - scale_r
    theta[:, :, :2] += torch.diag_embed(scale)

    # align_corners=False matches the current PyTorch default; passing it
    # explicitly silences the per-call UserWarning and pins the sampling
    # convention against future default changes.
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    x = F.grid_sample(
        x.float(), grid.float(), mode=mode, padding_mode=padding_mode,
        align_corners=False)
    return x


# Registry mapping each policy name (as used in the comma-separated ``policy``
# argument of ``diff_aug``) to the ordered list of augmentation functions that
# are applied for it.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'translation_scale': [rand_translation_scale],
    'cutout': [rand_cutout],
}