jamino30 commited on
Commit
bbcd902
·
verified ·
1 Parent(s): a9077eb

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +5 -0
  2. inference.py +2 -5
app.py CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime, timezone, timedelta
4
 
5
  import spaces
6
  import torch
 
7
  import numpy as np
8
  import gradio as gr
9
  from gradio_imageslider import ImageSlider
@@ -21,6 +22,9 @@ if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
21
  model = VGG_19().to(device).eval()
22
  for param in model.parameters():
23
  param.requires_grad = False
 
 
 
24
 
25
  style_files = os.listdir('./style_images')
26
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
@@ -51,6 +55,7 @@ def run(content_image, style_name, style_strength=5, apply_to_background=False,
51
  st = time.time()
52
  generated_img = inference(
53
  model=model,
 
54
  content_image=content_img,
55
  style_features=style_features,
56
  lr=lrs[style_strength-1],
 
4
 
5
  import spaces
6
  import torch
7
+ import torchvision.models as models
8
  import numpy as np
9
  import gradio as gr
10
  from gradio_imageslider import ImageSlider
 
22
  model = VGG_19().to(device).eval()
23
  for param in model.parameters():
24
  param.requires_grad = False
25
+ segmentation_model = models.segmentation.deeplabv3_resnet101(
26
+ weights='DEFAULT'
27
+ ).to(device).eval()
28
 
29
  style_files = os.listdir('./style_images')
30
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
 
55
  st = time.time()
56
  generated_img = inference(
57
  model=model,
58
+ segmentation_model=segmentation_model,
59
  content_image=content_img,
60
  style_features=style_features,
61
  lr=lrs[style_strength-1],
inference.py CHANGED
@@ -28,7 +28,6 @@ def _compute_loss(generated_features, content_features, style_features, resized_
28
  else:
29
  G = _gram_matrix(gf)
30
  A = _gram_matrix(sf)
31
- style_loss += w_l * F.mse_loss(G, A)
32
  style_loss += w_l * F.mse_loss(G, A)
33
 
34
  return alpha * content_loss + beta * style_loss
@@ -36,6 +35,7 @@ def _compute_loss(generated_features, content_features, style_features, resized_
36
  def inference(
37
  *,
38
  model,
 
39
  content_image,
40
  style_features,
41
  apply_to_background,
@@ -53,10 +53,7 @@ def inference(
53
  content_features = model(content_image)
54
 
55
  resized_bg_masks = []
56
- if apply_to_background:
57
- segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()
58
- segmentation_model = segmentation_model.to(content_image.device)
59
-
60
  segmentation_output = segmentation_model(content_image)['out']
61
  segmentation_mask = segmentation_output.argmax(dim=1)
62
 
 
28
  else:
29
  G = _gram_matrix(gf)
30
  A = _gram_matrix(sf)
 
31
  style_loss += w_l * F.mse_loss(G, A)
32
 
33
  return alpha * content_loss + beta * style_loss
 
35
  def inference(
36
  *,
37
  model,
38
+ segmentation_model,
39
  content_image,
40
  style_features,
41
  apply_to_background,
 
53
  content_features = model(content_image)
54
 
55
  resized_bg_masks = []
56
+ if apply_to_background:
 
 
 
57
  segmentation_output = segmentation_model(content_image)['out']
58
  segmentation_mask = segmentation_output.argmax(dim=1)
59