jamino30 commited on
Commit
57a96a1
1 Parent(s): f04906e

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +34 -9
  2. utils.py +3 -3
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  import torch.optim as optim
9
  import gradio as gr
10
 
11
- from utils import load_img, load_img_from_path, save_img
12
  from vgg19 import VGG_19
13
 
14
  if torch.cuda.is_available(): device = 'cuda'
@@ -31,15 +31,15 @@ def inference(content_image, style_image, style_strength, output_quality, progre
31
  print('DATETIME:', datetime.datetime.now())
32
  print('STYLE:', style_image)
33
  img_size = 1024 if output_quality else 512
34
- content_img, original_size = load_img(content_image, img_size)
35
  content_img = content_img.to(device)
36
- style_img = load_img_from_path(style_options[style_image], img_size)[0].to(device)
37
 
38
  print('CONTENT IMG SIZE:', original_size)
39
  print('STYLE STRENGTH:', style_strength)
40
  print('HIGH QUALITY:', output_quality)
41
 
42
- iters = 50
43
  # learning rate determined by input
44
  lr = 0.001 + (0.099 / 99) * (style_strength - 1)
45
  alpha = 1
@@ -49,7 +49,7 @@ def inference(content_image, style_image, style_strength, output_quality, progre
49
  generated_img = content_img.clone().requires_grad_(True)
50
  optimizer = optim.Adam([generated_img], lr=lr)
51
 
52
- for iter in tqdm(range(iters), desc='The magic is happening ✨'):
53
  generated_features = model(generated_img)
54
  content_features = model(content_img)
55
  style_features = model(style_img)
@@ -76,7 +76,7 @@ def inference(content_image, style_image, style_strength, output_quality, progre
76
 
77
  et = time.time()
78
  print('TIME TAKEN:', et-st)
79
- yield save_img(generated_img, original_size)
80
 
81
 
82
  def set_slider(value):
@@ -92,7 +92,7 @@ css = """
92
  with gr.Blocks(css=css) as demo:
93
  gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer</h1>")
94
  with gr.Column(elem_id='container'):
95
- content_and_output = gr.Image(show_label=False, type='pil', sources=['upload'], format='jpg')
96
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
97
  with gr.Accordion('Adjustments', open=False):
98
  with gr.Group():
@@ -103,9 +103,34 @@ with gr.Blocks(css=css) as demo:
103
  high_button = gr.Button('High').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
104
  with gr.Group():
105
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.', value=True)
106
- submit_button = gr.Button('Submit')
 
 
 
 
 
 
 
107
 
108
- submit_button.click(fn=inference, inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality], outputs=[content_and_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  examples = gr.Examples(
111
  examples=[
 
8
  import torch.optim as optim
9
  import gradio as gr
10
 
11
+ from utils import preprocess_img, preprocess_img_from_path, postprocess_img
12
  from vgg19 import VGG_19
13
 
14
  if torch.cuda.is_available(): device = 'cuda'
 
31
  print('DATETIME:', datetime.datetime.now())
32
  print('STYLE:', style_image)
33
  img_size = 1024 if output_quality else 512
34
+ content_img, original_size = preprocess_img(content_image, img_size)
35
  content_img = content_img.to(device)
36
+ style_img = preprocess_img_from_path(style_options[style_image], img_size)[0].to(device)
37
 
38
  print('CONTENT IMG SIZE:', original_size)
39
  print('STYLE STRENGTH:', style_strength)
40
  print('HIGH QUALITY:', output_quality)
41
 
42
+ iters = 1
43
  # learning rate determined by input
44
  lr = 0.001 + (0.099 / 99) * (style_strength - 1)
45
  alpha = 1
 
49
  generated_img = content_img.clone().requires_grad_(True)
50
  optimizer = optim.Adam([generated_img], lr=lr)
51
 
52
+ for _ in tqdm(range(iters), desc='The magic is happening ✨'):
53
  generated_features = model(generated_img)
54
  content_features = model(content_img)
55
  style_features = model(style_img)
 
76
 
77
  et = time.time()
78
  print('TIME TAKEN:', et-st)
79
+ yield postprocess_img(generated_img, original_size)
80
 
81
 
82
  def set_slider(value):
 
92
  with gr.Blocks(css=css) as demo:
93
  gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer</h1>")
94
  with gr.Column(elem_id='container'):
95
+ content_and_output = gr.Image(show_label=False, type='pil', sources=['upload'], format='jpg', show_download_button=False)
96
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
97
  with gr.Accordion('Adjustments', open=False):
98
  with gr.Group():
 
103
  high_button = gr.Button('High').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
104
  with gr.Group():
105
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.', value=True)
106
+
107
+ submit_button = gr.Button('Submit', variant='primary')
108
+ download_button = gr.DownloadButton(label='Download Image', visible=False)
109
+
110
+ def save_generated_image(img):
111
+ output_path = 'generated.jpg'
112
+ img.save(output_path)
113
+ return output_path
114
 
115
+ submit_button.click(
116
+ fn=inference,
117
+ inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality],
118
+ outputs=[content_and_output]
119
+ ).then(
120
+ fn=save_generated_image,
121
+ inputs=[content_and_output],
122
+ outputs=[download_button]
123
+ ).then(
124
+ fn=lambda _: gr.update(visible=True),
125
+ inputs=[],
126
+ outputs=[download_button]
127
+ )
128
+
129
+ content_and_output.change(
130
+ fn=lambda _: gr.update(visible=False),
131
+ inputs=[content_and_output],
132
+ outputs=[download_button]
133
+ )
134
 
135
  examples = gr.Examples(
136
  examples=[
utils.py CHANGED
@@ -3,7 +3,7 @@ from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
 
6
- def load_img(img: Image, img_size):
7
  original_size = img.size
8
 
9
  transform = transforms.Compose([
@@ -13,7 +13,7 @@ def load_img(img: Image, img_size):
13
  img = transform(img).unsqueeze(0)
14
  return img, original_size
15
 
16
- def load_img_from_path(path_to_image, img_size):
17
  img = Image.open(path_to_image)
18
  original_size = img.size
19
 
@@ -24,7 +24,7 @@ def load_img_from_path(path_to_image, img_size):
24
  img = transform(img).unsqueeze(0)
25
  return img, original_size
26
 
27
- def save_img(img, original_size):
28
  img = img.cpu().clone()
29
  img = img.squeeze(0)
30
 
 
3
  import torch
4
  import torchvision.transforms as transforms
5
 
6
+ def preprocess_img(img: Image, img_size):
7
  original_size = img.size
8
 
9
  transform = transforms.Compose([
 
13
  img = transform(img).unsqueeze(0)
14
  return img, original_size
15
 
16
+ def preprocess_img_from_path(path_to_image, img_size):
17
  img = Image.open(path_to_image)
18
  original_size = img.size
19
 
 
24
  img = transform(img).unsqueeze(0)
25
  return img, original_size
26
 
27
+ def postprocess_img(img, original_size):
28
  img = img.cpu().clone()
29
  img = img.squeeze(0)
30