jamino30 commited on
Commit
980e9a0
·
verified ·
1 Parent(s): fb14dcf

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +20 -17
app.py CHANGED
@@ -33,30 +33,33 @@ optimal_settings = {
33
  'Watercolor': (75, False),
34
  }
35
 
36
- cached_style_features = {}
37
- for style_name, style_img_path in style_options.items():
38
- style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
39
- style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
40
- with torch.no_grad():
41
- style_features = (model(style_img_512), model(style_img_1024))
42
- cached_style_features[style_name] = style_features
43
-
44
  def gram_matrix(feature):
45
  batch_size, n_feature_maps, height, width = feature.size()
46
- return torch.mm(
47
- feature.view(batch_size * n_feature_maps, height * width),
48
- feature.view(batch_size * n_feature_maps, height * width).t()
49
- )
50
 
51
- def compute_loss(generated_features, content_features, style_features, alpha, beta):
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  content_loss = 0
53
  style_loss = 0
54
 
55
- for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
56
  content_loss += torch.mean((generated_feature - content_feature) ** 2)
57
 
58
  G = gram_matrix(generated_feature)
59
- A = gram_matrix(style_feature)
60
 
61
  E_l = ((G - A) ** 2)
62
  w_l = 1 / 5
@@ -91,13 +94,13 @@ def inference(content_image, style_name, style_strength, output_quality, progres
91
  with torch.no_grad():
92
  content_features = model(content_img)
93
 
94
- style_features = cached_style_features[style_name][0 if img_size == 512 else 1]
95
 
96
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
97
  optimizer.zero_grad()
98
 
99
  generated_features = model(generated_img)
100
- total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
101
 
102
  total_loss.backward()
103
  optimizer.step()
 
33
  'Watercolor': (75, False),
34
  }
35
 
 
 
 
 
 
 
 
 
36
  def gram_matrix(feature):
37
  batch_size, n_feature_maps, height, width = feature.size()
38
+ new_feature = feature.view(batch_size * n_feature_maps, height * width)
39
+ return torch.mm(new_feature, new_feature.t())
 
 
40
 
41
+ cached_style_gram_matrices = {}
42
+ for style_name, style_img_path in tqdm(style_options.items(), desc='Computing style gram matrices'):
43
+ style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
44
+ style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
45
+ with torch.no_grad():
46
+ style_features_512 = model(style_img_512)
47
+ style_features_1024 = model(style_img_1024)
48
+ # compute gram matrices
49
+ style_gram_matrices_512 = [gram_matrix(f) for f in style_features_512]
50
+ style_gram_matrices_1024 = [gram_matrix(f) for f in style_features_1024]
51
+ cached_style_gram_matrices[style_name] = (style_gram_matrices_512, style_gram_matrices_1024)
52
+ print('Style caching complete.')
53
+
54
+ def compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta):
55
  content_loss = 0
56
  style_loss = 0
57
 
58
+ for generated_feature, content_feature, style_gram_matrix in zip(generated_features, content_features, style_gram_matrices):
59
  content_loss += torch.mean((generated_feature - content_feature) ** 2)
60
 
61
  G = gram_matrix(generated_feature)
62
+ A = style_gram_matrix
63
 
64
  E_l = ((G - A) ** 2)
65
  w_l = 1 / 5
 
94
  with torch.no_grad():
95
  content_features = model(content_img)
96
 
97
+ style_gram_matrices = cached_style_gram_matrices[style_name][0 if img_size == 512 else 1]
98
 
99
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
100
  optimizer.zero_grad()
101
 
102
  generated_features = model(generated_img)
103
+ total_loss = compute_loss(generated_features, content_features, style_gram_matrices, alpha, beta)
104
 
105
  total_loss.backward()
106
  optimizer.step()