jamino30 commited on
Commit
e21f7c8
·
verified ·
1 Parent(s): 9549c37

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +31 -20
  2. generated-all.jpg +0 -0
  3. generated-bg.jpg +0 -0
  4. inference.py +9 -3
app.py CHANGED
@@ -41,7 +41,7 @@ for style_name, style_img_path in style_options.items():
41
 
42
  @spaces.GPU(duration=12)
43
  def run(content_image, style_name, style_strength=5):
44
- yield None, None
45
  content_img, original_size = preprocess_img(content_image, img_size)
46
  content_img = content_img.to(device)
47
 
@@ -55,32 +55,41 @@ def run(content_image, style_name, style_strength=5):
55
 
56
  st = time.time()
57
 
58
- stream_all = torch.cuda.Stream()
59
- stream_bg = torch.cuda.Stream()
 
60
 
61
- def run_inference(apply_to_background, stream):
62
  with torch.cuda.stream(stream):
63
- return inference(
64
- model=model,
65
- segmentation_model=segmentation_model,
66
- content_image=content_img,
67
- style_features=style_features,
68
- lr=lrs[style_strength-1],
69
- apply_to_background=apply_to_background
70
- )
 
 
 
71
 
72
  with ThreadPoolExecutor() as executor:
73
- future_all = executor.submit(run_inference, False, stream_all)
74
- future_bg = executor.submit(run_inference, True, stream_bg)
75
- generated_img_all = future_all.result()
76
- generated_img_bg = future_bg.result()
 
 
 
 
77
 
78
  et = time.time()
79
  print('TIME TAKEN:', et-st)
80
 
81
  yield (
82
  (content_image, postprocess_img(generated_img_all, original_size)),
83
- (content_image, postprocess_img(generated_img_bg, original_size))
 
84
  )
85
 
86
  def set_slider(value):
@@ -115,7 +124,9 @@ with gr.Blocks(css=css) as demo:
115
  with gr.Column():
116
  output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
117
  download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
118
- output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
 
 
119
  download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
120
 
121
  def save_image(img_tuple1, img_tuple2):
@@ -132,7 +143,7 @@ with gr.Blocks(css=css) as demo:
132
  submit_button.click(
133
  fn=run,
134
  inputs=[content_image, style_dropdown, style_strength_slider],
135
- outputs=[output_image_all, output_image_background]
136
  ).then(
137
  fn=save_image,
138
  inputs=[output_image_all, output_image_background],
@@ -144,4 +155,4 @@ with gr.Blocks(css=css) as demo:
144
 
145
  demo.queue = False
146
  demo.config['queue'] = False
147
- demo.launch(show_api=False)
 
41
 
42
  @spaces.GPU(duration=12)
43
  def run(content_image, style_name, style_strength=5):
44
+ yield [None] * 3
45
  content_img, original_size = preprocess_img(content_image, img_size)
46
  content_img = content_img.to(device)
47
 
 
55
 
56
  st = time.time()
57
 
58
+ if device == 'cuda':
59
+ stream_all = torch.cuda.Stream()
60
+ stream_bg = torch.cuda.Stream()
61
 
62
+ def run_inference_cuda(apply_to_background, stream):
63
  with torch.cuda.stream(stream):
64
+ return run_inference(apply_to_background)
65
+
66
+ def run_inference(apply_to_background):
67
+ return inference(
68
+ model=model,
69
+ segmentation_model=segmentation_model,
70
+ content_image=content_img,
71
+ style_features=style_features,
72
+ lr=lrs[style_strength-1],
73
+ apply_to_background=apply_to_background
74
+ )
75
 
76
  with ThreadPoolExecutor() as executor:
77
+ if device == 'cuda':
78
+ future_all = executor.submit(run_inference_cuda, False, stream_all)
79
+ future_bg = executor.submit(run_inference_cuda, True, stream_bg)
80
+ else:
81
+ future_all = executor.submit(run_inference, False)
82
+ future_bg = executor.submit(run_inference, True)
83
+ generated_img_all, _ = future_all.result()
84
+ generated_img_bg, bg_ratio = future_bg.result()
85
 
86
  et = time.time()
87
  print('TIME TAKEN:', et-st)
88
 
89
  yield (
90
  (content_image, postprocess_img(generated_img_all, original_size)),
91
+ (content_image, postprocess_img(generated_img_bg, original_size)),
92
+ f'{bg_ratio:.2f}'
93
  )
94
 
95
  def set_slider(value):
 
124
  with gr.Column():
125
  output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
126
  download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
127
+ with gr.Group():
128
+ output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
129
+ bg_ratio_label = gr.Label(label='Background Ratio')
130
  download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
131
 
132
  def save_image(img_tuple1, img_tuple2):
 
143
  submit_button.click(
144
  fn=run,
145
  inputs=[content_image, style_dropdown, style_strength_slider],
146
+ outputs=[output_image_all, output_image_background, bg_ratio_label]
147
  ).then(
148
  fn=save_image,
149
  inputs=[output_image_all, output_image_background],
 
155
 
156
  demo.queue = False
157
  demo.config['queue'] = False
158
+ demo.launch(show_api=False)
generated-all.jpg ADDED
generated-bg.jpg ADDED
inference.py CHANGED
@@ -52,13 +52,19 @@ def inference(
52
  with torch.no_grad():
53
  content_features = model(content_image)
54
 
55
- resized_bg_masks = []
 
56
  if apply_to_background:
57
  segmentation_output = segmentation_model(content_image)['out']
58
  segmentation_mask = segmentation_output.argmax(dim=1)
59
 
60
  background_mask = (segmentation_mask == 0).float()
61
  foreground_mask = (segmentation_mask != 0).float()
 
 
 
 
 
62
 
63
  for cf in content_features:
64
  _, _, h_i, w_i = cf.shape
@@ -83,6 +89,6 @@ def inference(
83
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
84
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
85
 
86
- if iter % 10 == 0: print(f'Loss ({iter}):', min_losses[iter])
87
 
88
- return generated_image
 
52
  with torch.no_grad():
53
  content_features = model(content_image)
54
 
55
+ resized_bg_masks = []
56
+ background_ratio = None
57
  if apply_to_background:
58
  segmentation_output = segmentation_model(content_image)['out']
59
  segmentation_mask = segmentation_output.argmax(dim=1)
60
 
61
  background_mask = (segmentation_mask == 0).float()
62
  foreground_mask = (segmentation_mask != 0).float()
63
+
64
+ background_pixel_count = background_mask.sum().item()
65
+ total_pixel_count = segmentation_mask.numel()
66
+ background_ratio = background_pixel_count / total_pixel_count
67
+ print(f'Background Detected: {background_ratio * 100:.2f}%')
68
 
69
  for cf in content_features:
70
  _, _, h_i, w_i = cf.shape
 
89
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
90
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
91
 
92
+ if iter % 10 == 0: print(f'[{'Background' if apply_to_background else 'Image'}] Loss ({iter}):', min_losses[iter])
93
 
94
+ return generated_image, background_ratio