Upload folder using huggingface_hub
app.py CHANGED
@@ -33,30 +33,33 @@ optimal_settings = {
     'Watercolor': (75, False),
 }
 
-cached_style_features = {}
-for style_name, style_img_path in style_options.items():
-    style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
-    style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
-    with torch.no_grad():
-        style_features = (model(style_img_512), model(style_img_1024))
-    cached_style_features[style_name] = style_features
-
 def gram_matrix(feature):
     batch_size, n_feature_maps, height, width = feature.size()
-    return torch.mm(
-        feature.view(batch_size * n_feature_maps, height * width),
-        feature.view(batch_size * n_feature_maps, height * width).t()
-    )
+    new_feature = feature.view(batch_size * n_feature_maps, height * width)
+    return torch.mm(new_feature, new_feature.t())
 
-def compute_loss(generated_features, content_features, style_features, alpha, beta):
+cached_style_gram_matrices = {}
+for style_name, style_img_path in tqdm(style_options.items(), desc='Computing style gram matrices'):
+    style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
+    style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
+    with torch.no_grad():
+        style_features_512 = model(style_img_512)
+        style_features_1024 = model(style_img_1024)
+    # compute gram matrices
+    style_gram_matrices_512 = [gram_matrix(f) for f in style_features_512]
+    style_gram_matrices_1024 = [gram_matrix(f) for f in style_features_1024]
+    cached_style_gram_matrices[style_name] = (style_gram_matrices_512, style_gram_matrices_1024)
+print('Style caching complete.')
+
+def compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta):
     content_loss = 0
     style_loss = 0
 
-    for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
+    for generated_feature, content_feature, style_gram_matrix in zip(generated_features, content_features, style_gram_matrices):
         content_loss += torch.mean((generated_feature - content_feature) ** 2)
 
         G = gram_matrix(generated_feature)
-        A = gram_matrix(style_feature)
+        A = style_gram_matrix
 
         E_l = ((G - A) ** 2)
         w_l = 1 / 5
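
Note: the `gram_matrix` refactor returns the same Gram matrix as before, just without building the flattened view twice. A minimal sketch of what it computes, with dummy shapes that are illustrative rather than taken from the app:

```python
import torch

def gram_matrix(feature):
    # flatten each feature map into a row, then take pairwise inner products
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

dummy = torch.randn(1, 64, 128, 128)  # one image, 64 feature maps of 128x128
G = gram_matrix(dummy)
print(G.shape)  # torch.Size([64, 64]): one correlation entry per pair of maps
```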
@@ -91,13 +94,13 @@ def inference(content_image, style_name, style_strength, output_quality, progres
     with torch.no_grad():
         content_features = model(content_img)
 
-    style_features = cached_style_features[style_name][0 if img_size == 512 else 1]
+    style_gram_matrices = cached_style_gram_matrices[style_name][0 if img_size == 512 else 1]
 
     for _ in tqdm(range(iters), desc='The magic is happening ✨'):
         optimizer.zero_grad()
 
         generated_features = model(generated_img)
-        total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
+        total_loss = compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta)
 
         total_loss.backward()
         optimizer.step()
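
Only the edited lines of `compute_loss` appear in this diff. As context, here is a hedged reconstruction of the whole function in the standard Gatys-style formulation; the `style_loss` accumulation and the final `alpha`/`beta` weighting are assumptions inferred from the visible fragments (`E_l`, `w_l = 1 / 5`, and the function's parameters), not lines shown above:

```python
import torch

def gram_matrix(feature):
    # same helper as refactored in the first hunk
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

def compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta):
    content_loss = 0
    style_loss = 0

    for generated_feature, content_feature, style_gram_matrix in zip(
            generated_features, content_features, style_gram_matrices):
        # content term: MSE between generated and content activations
        content_loss += torch.mean((generated_feature - content_feature) ** 2)

        # style term: squared distance between Gram matrices, with the
        # style layers weighted equally (w_l = 1 / 5)
        G = gram_matrix(generated_feature)
        A = style_gram_matrix
        E_l = ((G - A) ** 2)
        w_l = 1 / 5
        style_loss += w_l * torch.mean(E_l)  # assumed accumulation

    return alpha * content_loss + beta * style_loss  # assumed weighting
```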
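
Taken together, the change hoists every `gram_matrix(style_feature)` call out of the optimization loop: style Gram matrices are computed once at startup and reused on every iteration, so each step only computes Gram matrices for the generated image. A self-contained toy sketch of that pattern, where the pooling "model" and all shapes are stand-ins rather than the app's actual feature extractor:

```python
import torch
import torch.nn.functional as F

def toy_model(img):
    # stand-in feature extractor: three progressively pooled "feature maps"
    return [F.avg_pool2d(img, 2 ** i) for i in range(1, 4)]

def gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    flat = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(flat, flat.t())

style_img = torch.rand(1, 3, 64, 64)

# computed once, outside the loop: this is what the cache now stores per style
with torch.no_grad():
    style_grams = [gram_matrix(f) for f in toy_model(style_img)]

generated_img = torch.rand(1, 3, 64, 64, requires_grad=True)
optimizer = torch.optim.Adam([generated_img], lr=0.05)

for _ in range(5):
    optimizer.zero_grad()
    # only the generated image's Gram matrices are recomputed each step
    loss = sum(torch.mean((gram_matrix(f) - A) ** 2)
               for f, A in zip(toy_model(generated_img), style_grams))
    loss.backward()
    optimizer.step()
```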