File size: 3,760 Bytes
91d9343
 
 
349bdfb
a9077eb
8d1740c
91d9343
 
 
 
 
a9077eb
91d9343
 
 
a9077eb
 
91d9343
a9077eb
 
 
 
 
 
 
 
 
 
91d9343
a9077eb
349bdfb
 
91d9343
 
 
 
bbcd902
91d9343
 
a9077eb
91d9343
ce6dca2
d3ca146
91d9343
a9077eb
91d9343
349bdfb
06894c7
d3ca146
c5d8238
91d9343
 
 
a9077eb
e21f7c8
 
bbcd902
a9077eb
 
 
8d1740c
e21f7c8
 
 
 
a9077eb
 
 
 
 
246dd82
b9f6209
91d9343
06894c7
349bdfb
a9077eb
 
91d9343
349bdfb
 
 
 
 
 
 
b9f6209
349bdfb
246dd82
 
cf753ac
fa762f9
a9077eb
 
 
 
 
 
349bdfb
 
e21f7c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import gaussian_blur

def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Return ``(content_loss, style_loss, total_loss)`` for one optimization step.

    The three feature sequences are walked layer by layer with ``zip``
    semantics, so iteration stops at the shortest one.

    * Content loss: unweighted sum of per-layer
      ``F.mse_loss(generated, content)``.
    * Style loss: per-layer ``F.mse_loss`` between the Gram matrices of the
      generated and style features. Each term is scaled by
      ``1 / len(generated_features)``.
    * Total loss: ``alpha * content_loss + beta * style_loss``.

    When ``resized_bg_masks`` is non-empty, ``resized_bg_masks[i]`` is blurred
    with ``gaussian_blur(kernel_size=5)``. The blurred mask is multiplied into
    both the generated and the style features of layer ``i`` before their Gram
    matrices are formed. Each mask must broadcast against that layer's
    features.

    Raises ZeroDivisionError when ``generated_features`` is empty.
    """
    layer_weight = 1 / len(generated_features)
    content_terms = []
    style_terms = []

    for idx, (gen, content, style) in enumerate(
        zip(generated_features, content_features, style_features)
    ):
        content_terms.append(F.mse_loss(gen, content))

        if resized_bg_masks:
            # Blurring softens the mask boundary, presumably so the style
            # term does not stop at a hard seam.
            soft_mask = gaussian_blur(resized_bg_masks[idx], kernel_size=5)
            gen = gen * soft_mask
            style = style * soft_mask

        gram_gen = _gram_matrix(gen)
        gram_style = _gram_matrix(style)
        style_terms.append(layer_weight * F.mse_loss(gram_gen, gram_style))

    content_loss = sum(content_terms)
    style_loss = sum(style_terms)
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss

def inference(
    *,
    model,
    segmentation_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimize a copy of ``content_image`` towards the given style features.

    Args:
        model: Callable mapping an image tensor to a sequence of feature
            tensors. Each feature is unpacked as 4-D ``(_, _, h, w)``.
        segmentation_model: Called only when ``apply_to_background`` is true.
            Its result is indexed with ``['out']`` and arg-maxed over dim 1;
            class index 0 is treated as background.
        content_image: Image tensor to start from. Content features are taken
            from it, and the segmentation is run on it.
        style_features: Precomputed style features, passed straight to
            ``_compute_loss``.
        apply_to_background: If true, the style loss is masked to the
            background. After every step, foreground pixels are reset to
            ``content_image``.
        lr: Learning rate passed to the optimizer.
        iterations: Number of ``optimizer.step`` calls.
        optim_caller: Optimizer factory, called as
            ``optim_caller([generated_image], lr=lr)``.
        alpha: Weight of the content loss.
        beta: Weight of the style loss.

    Returns:
        ``(generated_image, background_ratio)``. ``generated_image`` still has
        ``requires_grad=True``. ``background_ratio`` is the fraction of pixels
        labelled 0 by the segmentation, or ``None`` when
        ``apply_to_background`` is false.

    Losses are logged to TensorBoard under the tag ``style-background`` or
    ``style-image``, once per closure evaluation.
    """
    # detach() guarantees the optimized tensor is a leaf even if the caller's
    # content_image is part of an autograd graph. Optimizers reject non-leaves.
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    with torch.no_grad():
        content_features = model(content_image)

        resized_bg_masks = []
        background_ratio = None
        foreground_mask_resized = None
        if apply_to_background:
            segmentation_output = segmentation_model(content_image)['out']
            segmentation_mask = segmentation_output.argmax(dim=1)
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = 1 - background_mask

            background_pixel_count = background_mask.sum().item()
            total_pixel_count = segmentation_mask.numel()
            background_ratio = background_pixel_count / total_pixel_count

            # One background mask per feature layer, resized to that layer's
            # spatial size.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)

            # Loop-invariant: the image size and the mask never change, so
            # resize the foreground mask once instead of on every step.
            foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')

    writer = SummaryWriter()
    tag = f'style-{"background" if apply_to_background else "image"}'

    def closure(step):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()

        # log loss
        writer.add_scalars(tag, {
            'Loss/content': content_loss.item(),
            'Loss/style': style_loss.item(),
            'Loss/total': total_loss.item(),
        }, step)

        return total_loss

    try:
        for step in range(iterations):
            # Bind `step` as a default so the closure sees this iteration's
            # value. Optimizers call the closure with no arguments, possibly
            # several times (e.g. LBFGS).
            optimizer.step(lambda step=step: closure(step))

            if apply_to_background:
                with torch.no_grad():
                    # Keep the foreground identical to the content image.
                    blended = generated_image * (1 - foreground_mask_resized) + content_image * foreground_mask_resized
                    generated_image.copy_(blended)
    finally:
        # Always release the event-file writer, even if optimization raised.
        writer.flush()
        writer.close()

    return generated_image, background_ratio