|
from tqdm import tqdm |
|
|
|
import torch |
|
import torch.optim as optim |
|
import torch.nn.functional as F |
|
|
|
def _gram_matrix(feature): |
|
batch_size, n_feature_maps, height, width = feature.size() |
|
new_feature = feature.view(batch_size * n_feature_maps, height * width) |
|
return torch.mm(new_feature, new_feature.t()) |
|
|
|
def _compute_loss(generated_features, content_features, style_features, alpha, beta): |
|
content_loss = 0 |
|
style_loss = 0 |
|
w_l = 1 / len(generated_features) |
|
for gf, cf, sf in zip(generated_features, content_features, style_features): |
|
content_loss += F.mse_loss(gf, cf) |
|
G = _gram_matrix(gf) |
|
A = _gram_matrix(sf) |
|
style_loss += w_l * F.mse_loss(G, A) |
|
return alpha * content_loss + beta * style_loss |
|
|
|
def inference(
    *,
    model,
    content_image,
    style_features,
    lr,
    iterations=100,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1
):
    """Optimize a copy of ``content_image`` against content and style features.

    A copy of the content image is made trainable and updated for
    ``iterations`` optimizer steps. Each step minimizes the value returned by
    ``_compute_loss`` for the model's features of the current image, the
    content image's features, and ``style_features``.

    Args:
        model: Callable mapping an image tensor to the feature tensors
            consumed by ``_compute_loss``.
        content_image: Image tensor used as the starting point and as the
            content target.
        style_features: Iterable of style feature tensors; presumably the
            output of ``model`` on a style image (confirm against caller).
        lr: Learning rate passed to the optimizer.
        iterations: Number of optimizer steps to run.
        optim_caller: Optimizer factory called as
            ``optim_caller([image], lr=lr)``.
        alpha: Content-loss weight forwarded to ``_compute_loss``.
        beta: Style-loss weight forwarded to ``_compute_loss``.

    Returns:
        The optimized image tensor (a leaf tensor that still requires grad).
    """
    # Detach before cloning so the optimized tensor is always a leaf, even if
    # the caller's image already takes part in an autograd graph.
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    # Seed with +inf so min() has a numeric value to compare against.
    min_losses = [float('inf')] * iterations

    with torch.no_grad():
        content_features = model(content_image)

    # The style targets are constants: detach them once so repeated backward
    # passes never try to walk through a graph they may still be attached to.
    style_targets = [sf.detach() for sf in style_features]

    def closure(step):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        total_loss = _compute_loss(
            generated_features, content_features, style_targets, alpha, beta
        )
        total_loss.backward()
        # Some optimizers (e.g. LBFGS) evaluate the closure several times per
        # step; keep the smallest loss observed during this step.
        min_losses[step] = min(min_losses[step], total_loss.item())
        return total_loss

    for step in tqdm(range(iterations), desc='The magic is happening ✨'):
        # Optimizer.step() calls the closure with no arguments, so bind the
        # current step index here.
        optimizer.step(lambda step=step: closure(step))
        print(f'Loss ({step+1}):', min_losses[step])

    return generated_image