jamino30 commited on
Commit
d3ca146
·
verified ·
1 Parent(s): 3e75c58

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +7 -15
inference.py CHANGED
@@ -26,33 +26,25 @@ def inference(
26
  content_image,
27
  style_features,
28
  lr,
29
- adam_iterations=1,
30
- lbfgs_iterations=3,
31
  alpha=1,
32
- beta=1,
33
- clip_grad_norm=5.0
34
  ):
35
- torch.manual_seed(42)
36
-
37
  generated_image = content_image.clone().requires_grad_(True)
 
38
 
39
  with torch.no_grad():
40
  content_features = model(content_image)
41
 
42
- def closure(optimizer):
43
  optimizer.zero_grad()
44
  generated_features = model(generated_image)
45
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
46
  total_loss.backward()
47
- torch.nn.utils.clip_grad_norm_([generated_image], max_norm=clip_grad_norm) # clip gradients
48
  return total_loss
49
 
50
- adam_optimizer = optim.AdamW([generated_image], lr=lr)
51
- for _ in tqdm(range(adam_iterations), desc='The magic is happening (1/2) ✨'):
52
- adam_optimizer.step(lambda: closure(adam_optimizer))
53
-
54
- lbfgs_optimizer = optim.LBFGS([generated_image], lr=lr)
55
- for _ in tqdm(range(lbfgs_iterations), desc='The magic is happening (2/2) ✨'):
56
- lbfgs_optimizer.step(lambda: closure(lbfgs_optimizer))
57
 
58
  return generated_image
 
26
  content_image,
27
  style_features,
28
  lr,
29
+ iterations=35,
30
+ optim_caller=optim.AdamW,
31
  alpha=1,
32
+ beta=1
 
33
  ):
 
 
34
  generated_image = content_image.clone().requires_grad_(True)
35
+ optimizer = optim_caller([generated_image], lr=lr)
36
 
37
  with torch.no_grad():
38
  content_features = model(content_image)
39
 
40
+ def closure():
41
  optimizer.zero_grad()
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
44
  total_loss.backward()
 
45
  return total_loss
46
 
47
+ for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
48
+ optimizer.step(closure)
 
 
 
 
 
49
 
50
  return generated_image