CSB261 commited on
Commit
eb7a233
โ€ข
1 Parent(s): 2e784eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -46
app.py CHANGED
@@ -1,13 +1,16 @@
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
 
3
  import spaces
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
7
  from PIL import Image
 
 
 
8
 
9
  torch.set_float32_matmul_precision(["high", "highest"][0])
10
-
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
@@ -20,65 +23,58 @@ transform_image = transforms.Compose(
20
  ]
21
  )
22
 
23
-
24
  @spaces.GPU
25
  def fn(image):
26
- if image is None or len(image) == 0:
27
- return image, None # ์›๋ณธ ์ด๋ฏธ์ง€๋„ ๋ฐ˜ํ™˜
28
- im = Image.open(image).convert("RGB")
 
29
  image_size = im.size
30
  origin = im.copy()
31
- input_images = transform_image(im).unsqueeze(0).to("cuda")
32
- # Prediction
 
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
35
  pred = preds[0].squeeze()
36
  pred_pil = transforms.ToPILImage()(pred)
37
  mask = pred_pil.resize(image_size)
38
- im.putalpha(mask)
39
- return im, origin # ๋ณ€ํ™˜๋œ ์ด๋ฏธ์ง€์™€ ์›๋ณธ ์ด๋ฏธ์ง€ ๋ฐ˜ํ™˜
40
-
41
 
42
  def save_image(image):
43
- if image is not None:
44
- image.save("output.png")
45
- return "output.png"
46
- return None
 
47
 
 
 
 
 
 
 
 
48
 
49
- with gr.Blocks() as demo:
50
- with gr.Row():
51
- with gr.Column(scale=1):
52
- image = gr.Image(label="Upload an image")
53
- text = gr.Textbox(label="Paste an image URL")
54
- download_button = gr.Button("Download Image")
55
- output_file = gr.File()
56
-
57
- with gr.Column(scale=2):
58
- slider1 = ImageSlider(label="Processed Image", type="pil")
59
- slider2 = ImageSlider(label="Original Image", type="pil")
60
 
61
- # ์ŠคํŽ˜์ด์Šค์— ์žˆ๋Š” ์˜ˆ์ œ ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฒฝ๋กœ
62
- example_image1 = "example_images/example1.jpg"
63
- example_image2 = "example_images/example2.jpg"
64
- example_image3 = "example_images/example3.jpg"
65
-
66
- with gr.Tab("Image Upload"):
67
- tab1 = gr.Interface(
68
- fn, inputs=image, outputs=[slider1, output_file],
69
- examples=[example_image1, example_image2, example_image3], api_name="image"
70
- )
71
 
72
- with gr.Tab("Image URL"):
73
- tab2 = gr.Interface(
74
- fn, inputs=text, outputs=[slider2, output_file],
75
- examples=[example_image1, example_image2, example_image3], api_name="text"
76
- )
77
-
78
- def process_download(image):
79
- return save_image(image[0])
80
-
81
- download_button.click(process_download, inputs=slider1, outputs=output_file)
82
 
83
  if __name__ == "__main__":
84
- demo.launch()
 
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
+ from loadimg import load_img
4
  import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
9
+ import io
10
+ import os
11
+ import tempfile
12
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
 
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
 
23
  ]
24
  )
25
 
 
26
  @spaces.GPU
27
  def fn(image):
28
+ if image is None:
29
+ return None, None
30
+ im = load_img(image, output_type="pil")
31
+ im = im.convert("RGB")
32
  image_size = im.size
33
  origin = im.copy()
34
+ image = load_img(im)
35
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
36
+ # ์˜ˆ์ธก
37
  with torch.no_grad():
38
  preds = birefnet(input_images)[-1].sigmoid().cpu()
39
  pred = preds[0].squeeze()
40
  pred_pil = transforms.ToPILImage()(pred)
41
  mask = pred_pil.resize(image_size)
42
+ image.putalpha(mask)
43
+ return image, origin
 
44
 
45
  def save_image(image):
46
+ if image is None:
47
+ return None
48
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
49
+ image.save(temp_file, format="PNG")
50
+ return temp_file.name
51
 
52
+ def process_and_download(input_image):
53
+ result, original = fn(input_image)
54
+ if result is None:
55
+ return None, None
56
+ result_path = save_image(result)
57
+ original_path = save_image(original)
58
+ return [result_path, original_path], result_path
59
 
60
+ image = gr.Image(label="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
61
+ slider = ImageSlider(label="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ ๊ฒฐ๊ณผ", type="filepath")
62
+ png_output = gr.File(label="PNG ๋‹ค์šด๋กœ๋“œ")
 
 
 
 
 
 
 
 
63
 
64
+ examples = [
65
+ os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ1.png"),
66
+ os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ2.png"),
67
+ os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ3.png")
68
+ ]
 
 
 
 
 
69
 
70
+ demo = gr.Interface(
71
+ process_and_download,
72
+ inputs=image,
73
+ outputs=[slider, png_output],
74
+ examples=examples,
75
+ title="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ",
76
+ description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด BiRefNet ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐฐ๊ฒฝ์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๊ฒฐ๊ณผ๋ฅผ PNG ํŒŒ์ผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
77
+ )
 
 
78
 
79
  if __name__ == "__main__":
80
+ demo.launch()