import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import gaussian_blur
def _gram_matrix(feature):
    """Gram matrix of a feature map; captures channel correlations used as style statistics."""
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())
def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)  # equal weight per feature layer

    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)

        if resized_bg_masks:
            # Restrict the style loss to the background region; blurring the mask
            # softens the foreground/background boundary.
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            masked_gf = gf * blurred_bg_mask
            masked_sf = sf * blurred_bg_mask
            G = _gram_matrix(masked_gf)
            A = _gram_matrix(masked_sf)
        else:
            G = _gram_matrix(gf)
            A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)

    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss
def inference(
    *,
    model,
    segmentation_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    writer = SummaryWriter()
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations  # best total loss seen at each iteration

    with torch.no_grad():
        content_features = model(content_image)

        resized_bg_masks = []
        background_ratio = None
        if apply_to_background:
            # Segment the content image; class 0 is treated as background.
            segmentation_output = segmentation_model(content_image)['out']
            segmentation_mask = segmentation_output.argmax(dim=1)
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = 1 - background_mask

            background_pixel_count = background_mask.sum().item()
            total_pixel_count = segmentation_mask.numel()
            background_ratio = background_pixel_count / total_pixel_count

            # Resize the background mask to the spatial size of each feature map.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)

    def closure(iter):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()

        # log loss
        writer.add_scalars(f'style-{"background" if apply_to_background else "image"}', {
            'Loss/content': content_loss.item(),
            'Loss/style': style_loss.item(),
            'Loss/total': total_loss.item()
        }, iter)
        min_losses[iter] = min(min_losses[iter], total_loss.item())

        return total_loss

    for iter in range(iterations):
        optimizer.step(lambda: closure(iter))

        if apply_to_background:
            # Paste the original foreground back so only the background is stylized.
            with torch.no_grad():
                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized

    writer.flush()
    writer.close()
    return generated_image, background_ratio
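

# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal driver showing how `inference` might be called, assuming:
#   * `model` is a feature extractor returning a list of feature maps
#     (here: selected VGG-19 conv activations; the layer choice is an
#     illustrative assumption, not taken from this repo), and
#   * `segmentation_model` is torchvision's DeepLabV3, whose output dict
#     matches the `['out']` access above, with background as class 0.
if __name__ == "__main__":
    from torchvision import models
    from torchvision.models.segmentation import deeplabv3_resnet101

    device = "cuda" if torch.cuda.is_available() else "cpu"

    class VGGFeatures(torch.nn.Module):
        """Collects activations at a few fixed VGG-19 layers (assumed choice)."""
        def __init__(self, layer_ids=(0, 5, 10, 19, 28)):
            super().__init__()
            self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
            self.layer_ids = set(layer_ids)
            for p in self.vgg.parameters():
                p.requires_grad_(False)

        def forward(self, x):
            features = []
            for i, layer in enumerate(self.vgg):
                x = layer(x)
                if i in self.layer_ids:
                    features.append(x)
            return features

    extractor = VGGFeatures().to(device)
    seg_model = deeplabv3_resnet101(weights="DEFAULT").eval().to(device)

    # Random stand-ins for real, normalized content/style images.
    content = torch.rand(1, 3, 256, 256, device=device)
    style = torch.rand(1, 3, 256, 256, device=device)
    with torch.no_grad():
        style_feats = extractor(style)

    stylized, bg_ratio = inference(
        model=extractor,
        segmentation_model=seg_model,
        content_image=content,
        style_features=style_feats,
        apply_to_background=True,
        lr=0.05,
        iterations=5,  # short run for the demo; the default is 101
    )
    print(f"background ratio: {bg_ratio:.2f}, output shape: {tuple(stylized.shape)}")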