CSB261 commited on
Commit
7d3d39b
โ€ข
1 Parent(s): 67c9b5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -6,15 +6,16 @@ from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
9
- import io
10
  import os
11
  import tempfile
12
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
 
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
  birefnet.to("cuda")
 
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
@@ -31,22 +32,21 @@ def fn(image):
31
  im = im.convert("RGB")
32
  image_size = im.size
33
  origin = im.copy()
34
- image = load_img(im)
35
- input_images = transform_image(image).unsqueeze(0).to("cuda")
36
  # ์˜ˆ์ธก
37
  with torch.no_grad():
38
  preds = birefnet(input_images)[-1].sigmoid().cpu()
39
  pred = preds[0].squeeze()
40
  pred_pil = transforms.ToPILImage()(pred)
41
  mask = pred_pil.resize(image_size)
42
- image.putalpha(mask)
43
- return image, origin
44
 
45
  def save_image(image):
46
  if image is None:
47
  return None
48
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
49
- image.save(temp_file, format="PNG")
50
  return temp_file.name
51
 
52
  def process_and_download(input_image):
@@ -57,24 +57,25 @@ def process_and_download(input_image):
57
  original_path = save_image(original)
58
  return [result_path, original_path], result_path
59
 
 
 
 
 
 
 
60
  image = gr.Image(label="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
61
  slider = ImageSlider(label="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ ๊ฒฐ๊ณผ", type="filepath")
62
  png_output = gr.File(label="PNG ๋‹ค์šด๋กœ๋“œ")
63
 
64
- examples = [
65
- os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ1.png"),
66
- os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ2.png"),
67
- os.path.join(os.path.dirname(__file__), "์˜ˆ์ œ3.png")
68
- ]
69
-
70
  demo = gr.Interface(
71
  process_and_download,
72
  inputs=image,
73
  outputs=[slider, png_output],
74
- examples=examples,
75
  title="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ",
76
  description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด BiRefNet ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐฐ๊ฒฝ์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๊ฒฐ๊ณผ๋ฅผ PNG ํŒŒ์ผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
77
  )
78
 
79
  if __name__ == "__main__":
80
- demo.launch()
 
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
 
9
  import os
10
  import tempfile
11
 
12
  torch.set_float32_matmul_precision(["high", "highest"][0])
13
+
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
  birefnet.to("cuda")
18
+
19
  transform_image = transforms.Compose(
20
  [
21
  transforms.Resize((1024, 1024)),
 
32
  im = im.convert("RGB")
33
  image_size = im.size
34
  origin = im.copy()
35
+ input_images = transform_image(im).unsqueeze(0).to("cuda")
 
36
  # ์˜ˆ์ธก
37
  with torch.no_grad():
38
  preds = birefnet(input_images)[-1].sigmoid().cpu()
39
  pred = preds[0].squeeze()
40
  pred_pil = transforms.ToPILImage()(pred)
41
  mask = pred_pil.resize(image_size)
42
+ im.putalpha(mask)
43
+ return im, origin
44
 
45
  def save_image(image):
46
  if image is None:
47
  return None
48
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
49
+ image.save(temp_file.name, format="PNG")
50
  return temp_file.name
51
 
52
  def process_and_download(input_image):
 
57
  original_path = save_image(original)
58
  return [result_path, original_path], result_path
59
 
60
+ # ์˜ˆ์ œ ์ด๋ฏธ์ง€๋ฅผ ์ง์ ‘ PIL ๊ฐ์ฒด๋กœ ๋กœ๋“œ
61
+ example_image1 = Image.open("example_images/example1.png")
62
+ example_image2 = Image.open("example_images/example2.png")
63
+ example_image3 = Image.open("example_images/example3.png")
64
+
65
+ # ์ธํ„ฐํŽ˜์ด์Šค ์ปดํฌ๋„ŒํŠธ ์ •์˜
66
  image = gr.Image(label="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
67
  slider = ImageSlider(label="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ ๊ฒฐ๊ณผ", type="filepath")
68
  png_output = gr.File(label="PNG ๋‹ค์šด๋กœ๋“œ")
69
 
70
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ๊ตฌ์„ฑ
 
 
 
 
 
71
  demo = gr.Interface(
72
  process_and_download,
73
  inputs=image,
74
  outputs=[slider, png_output],
75
+ examples=[example_image1, example_image2, example_image3],
76
  title="๋ฐฐ๊ฒฝ ์ œ๊ฑฐ",
77
  description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด BiRefNet ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐฐ๊ฒฝ์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. ๊ฒฐ๊ณผ๋ฅผ PNG ํŒŒ์ผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
78
  )
79
 
80
  if __name__ == "__main__":
81
+ demo.launch()