jamino30 commited on
Commit
b7a47e5
·
verified ·
1 Parent(s): 980e9a0

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +17 -20
app.py CHANGED
@@ -33,33 +33,30 @@ optimal_settings = {
33
  'Watercolor': (75, False),
34
  }
35
 
36
- def gram_matrix(feature):
37
- batch_size, n_feature_maps, height, width = feature.size()
38
- new_feature = feature.view(batch_size * n_feature_maps, height * width)
39
- return torch.mm(new_feature, new_feature.t())
40
-
41
- cached_style_gram_matrices = {}
42
- for style_name, style_img_path in tqdm(style_options.items(), desc='Computing style gram matrices'):
43
  style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
44
  style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
45
  with torch.no_grad():
46
- style_features_512 = model(style_img_512)
47
- style_features_1024 = model(style_img_1024)
48
- # compute gram matrices
49
- style_gram_matrices_512 = [gram_matrix(f) for f in style_features_512]
50
- style_gram_matrices_1024 = [gram_matrix(f) for f in style_features_1024]
51
- cached_style_gram_matrices[style_name] = (style_gram_matrices_512, style_gram_matrices_1024)
52
- print('Style caching complete.')
53
-
54
- def compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta):
 
 
55
  content_loss = 0
56
  style_loss = 0
57
 
58
- for generated_feature, content_feature, style_gram_matrix in zip(generated_features, content_features, style_gram_matrices):
59
  content_loss += torch.mean((generated_feature - content_feature) ** 2)
60
 
61
  G = gram_matrix(generated_feature)
62
- A = style_gram_matrix
63
 
64
  E_l = ((G - A) ** 2)
65
  w_l = 1 / 5
@@ -94,13 +91,13 @@ def inference(content_image, style_name, style_strength, output_quality, progres
94
  with torch.no_grad():
95
  content_features = model(content_img)
96
 
97
- style_gram_matrices = cached_style_gram_matrices[style_name][0 if img_size == 512 else 1]
98
 
99
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
100
  optimizer.zero_grad()
101
 
102
  generated_features = model(generated_img)
103
- total_loss = compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta)
104
 
105
  total_loss.backward()
106
  optimizer.step()
 
33
  'Watercolor': (75, False),
34
  }
35
 
36
+ cached_style_features = {}
37
+ for style_name, style_img_path in style_options.items():
 
 
 
 
 
38
  style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
39
  style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
40
  with torch.no_grad():
41
+ style_features = (model(style_img_512), model(style_img_1024))
42
+ cached_style_features[style_name] = style_features
43
+
44
+ def gram_matrix(feature):
45
+ batch_size, n_feature_maps, height, width = feature.size()
46
+ return torch.mm(
47
+ feature.view(batch_size * n_feature_maps, height * width),
48
+ feature.view(batch_size * n_feature_maps, height * width).t()
49
+ )
50
+
51
+ def compute_loss(generated_features, content_features, style_features, alpha, beta):
52
  content_loss = 0
53
  style_loss = 0
54
 
55
+ for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
56
  content_loss += torch.mean((generated_feature - content_feature) ** 2)
57
 
58
  G = gram_matrix(generated_feature)
59
+ A = gram_matrix(style_feature)
60
 
61
  E_l = ((G - A) ** 2)
62
  w_l = 1 / 5
 
91
  with torch.no_grad():
92
  content_features = model(content_img)
93
 
94
+ style_features = cached_style_features[style_name][0 if img_size == 512 else 1]
95
 
96
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
97
  optimizer.zero_grad()
98
 
99
  generated_features = model(generated_img)
100
+ total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
101
 
102
  total_loss.backward()
103
  optimizer.step()