Upload folder using huggingface_hub

app.py CHANGED
@@ -33,33 +33,30 @@ optimal_settings = {
     'Watercolor': (75, False),
 }

-def gram_matrix(feature):
-    batch_size, n_feature_maps, height, width = feature.size()
-    new_feature = feature.view(batch_size * n_feature_maps, height * width)
-    return torch.mm(new_feature, new_feature.t())
-
-cached_style_gram_matrices = {}
-for style_name, style_img_path in tqdm(style_options.items(), desc='Computing style gram matrices'):
+cached_style_features = {}
+for style_name, style_img_path in style_options.items():
     style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
     style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
     with torch.no_grad():
-
-
-
-
-
-
-
-
-
+        style_features = (model(style_img_512), model(style_img_1024))
+        cached_style_features[style_name] = style_features
+
+def gram_matrix(feature):
+    batch_size, n_feature_maps, height, width = feature.size()
+    return torch.mm(
+        feature.view(batch_size * n_feature_maps, height * width),
+        feature.view(batch_size * n_feature_maps, height * width).t()
+    )
+
+def compute_loss(generated_features, content_features, style_features, alpha, beta):
     content_loss = 0
     style_loss = 0

-    for generated_feature, content_feature,
+    for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
         content_loss += torch.mean((generated_feature - content_feature) ** 2)

         G = gram_matrix(generated_feature)
-        A =
+        A = gram_matrix(style_feature)

         E_l = ((G - A) ** 2)
         w_l = 1 / 5
@@ -94,13 +91,13 @@ def inference(content_image, style_name, style_strength, output_quality, progres
     with torch.no_grad():
         content_features = model(content_img)

-
+    style_features = cached_style_features[style_name][0 if img_size == 512 else 1]

     for _ in tqdm(range(iters), desc='The magic is happening ✨'):
         optimizer.zero_grad()

         generated_features = model(generated_img)
-        total_loss = compute_loss(generated_features, content_features,
+        total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)

         total_loss.backward()
         optimizer.step()
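
Taken together, the two hunks replace a per-style Gram-matrix cache with a per-style feature cache and move the Gram computation into the loss function. For reference, here is a minimal standalone sketch of the added gram_matrix and compute_loss code. The diff cuts off after w_l = 1 / 5, so the style_loss accumulation and the final alpha/beta weighting below are assumptions, modeled on the standard Gatys et al. style-transfer loss and on the parameters compute_loss receives; the random tensors merely stand in for VGG feature maps.

    import torch

    def gram_matrix(feature):
        # Gram (channel-correlation) matrix of one activation map.
        batch_size, n_feature_maps, height, width = feature.size()
        return torch.mm(
            feature.view(batch_size * n_feature_maps, height * width),
            feature.view(batch_size * n_feature_maps, height * width).t()
        )

    def compute_loss(generated_features, content_features, style_features, alpha, beta):
        content_loss = 0
        style_loss = 0
        for generated_feature, content_feature, style_feature in zip(
                generated_features, content_features, style_features):
            content_loss += torch.mean((generated_feature - content_feature) ** 2)
            G = gram_matrix(generated_feature)
            A = gram_matrix(style_feature)  # recomputed on every call in the new version
            E_l = ((G - A) ** 2)
            w_l = 1 / 5  # equal weight per layer; five layers assumed
            style_loss += w_l * torch.mean(E_l)  # assumed: not shown in the diff
        return alpha * content_loss + beta * style_loss  # assumed: not shown in the diff

    # Toy check with five random "layers" standing in for VGG activations.
    fake = lambda: [torch.randn(1, 64, 32, 32) for _ in range(5)]
    print(compute_loss(fake(), fake(), fake(), alpha=1.0, beta=1e3).item())

One trade-off worth noting: the old code precomputed each style's Gram matrices once at startup, while the new code caches raw feature maps and appears to recompute A = gram_matrix(style_feature) on every optimization step inside compute_loss, a little extra work per iteration in exchange for a simpler cache and startup path.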
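The second hunk relies on the cache layout established in the first: each style name maps to a (features_at_512, features_at_1024) tuple, and inference indexes it by the requested output size. A minimal illustration of that lookup, with placeholder strings instead of model feature maps (img_size is assumed to be derived from the selected output settings earlier in inference):

    # Placeholder strings stand in for tuples of model feature maps.
    cached_style_features = {'Watercolor': ('features@512', 'features@1024')}

    img_size = 1024  # assumed to come from the UI's output-size setting
    style_features = cached_style_features['Watercolor'][0 if img_size == 512 else 1]
    print(style_features)  # -> features@1024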