1-13-am commited on
Commit
6c74699
1 Parent(s): a5f14de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torchvision
 
3
 
4
  def normalize():
5
  MEAN = [0.485, 0.456, 0.406]
@@ -35,7 +36,7 @@ def tensor_to_img(tensor):
35
  def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0., style_img_2 = None, iw_2 = 0., style_img_3 = None, iw_3 = 0., preserve_color = None):
36
  transform = transformer(imsize = 512)
37
 
38
- content = transform(content_img).unsqueeze(0).cuda()
39
 
40
  iw = [iw_1, iw_2, iw_3]
41
  interpolation_weights = [i/ sum(iw) for i in iw]
@@ -44,7 +45,7 @@ def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0., s
44
  styles = []
45
  for style_img in style_imgs:
46
  if style_img is not None:
47
- styles.append(transform(style_img).unsqueeze(0).cuda())
48
  if preserve_color == "None": preserve_color = None
49
  elif preserve_color == "Whitening": preserve_color = "batch_wct"
50
  #elif preserve_color == "Histogram matching": preserve_color = "histogram_matching"
@@ -80,4 +81,4 @@ interface = gr.Interface(fn = style_transfer,
80
  description = None
81
  )
82
  interface.queue()
83
- interface.launch(share = True, debug = True)
 
1
  import gradio as gr
2
  import torchvision
3
+ import torchvision.transforms as transforms
4
 
5
  def normalize():
6
  MEAN = [0.485, 0.456, 0.406]
 
36
  def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0., style_img_2 = None, iw_2 = 0., style_img_3 = None, iw_3 = 0., preserve_color = None):
37
  transform = transformer(imsize = 512)
38
 
39
+ content = transform(content_img).unsqueeze(0)
40
 
41
  iw = [iw_1, iw_2, iw_3]
42
  interpolation_weights = [i/ sum(iw) for i in iw]
 
45
  styles = []
46
  for style_img in style_imgs:
47
  if style_img is not None:
48
+ styles.append(transform(style_img).unsqueeze(0))
49
  if preserve_color == "None": preserve_color = None
50
  elif preserve_color == "Whitening": preserve_color = "batch_wct"
51
  #elif preserve_color == "Histogram matching": preserve_color = "histogram_matching"
 
81
  description = None
82
  )
83
  interface.queue()
84
+ interface.launch(share = True)