# CSB261's picture
# Update app.py
# 7d3d39b verified
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
import tempfile
# Let float32 matmuls use faster reduced-precision internal math where the
# hardware supports it; pass "highest" instead for full float32 precision.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model from the "ZhengPeng7/BiRefNet" Hub repository.
# trust_remote_code=True lets transformers execute the model code that the
# repository itself provides.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to every input image: resize to a fixed 1024x1024,
# convert to a tensor, then normalize with the standard ImageNet channel
# mean/std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@spaces.GPU
def fn(image):
    """Cut the background out of *image* using the BiRefNet model.

    Args:
        image: Any input ``load_img`` accepts (here, whatever the
            ``gr.Image`` component passes in), or ``None``.

    Returns:
        A ``(cutout, original)`` pair of PIL images. ``cutout`` is the RGB
        input with the predicted mask attached as its alpha channel, and
        ``original`` is an unmodified RGB copy of the input. Returns
        ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    rgb = load_img(image, output_type="pil").convert("RGB")
    # Snapshot before putalpha() below mutates `rgb` in place.
    untouched = rgb.copy()

    batch = transform_image(rgb).unsqueeze(0).to("cuda")

    # Inference: take the model's last output and squash it through a
    # sigmoid into per-pixel values in (0, 1), then move it to the CPU.
    with torch.no_grad():
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # Assumes the output is (batch, 1, H, W); dropping the batch index and
    # squeezing leaves a single (H, W) mask -- TODO confirm against the model.
    # The mask is then resized back to the input's original size.
    mask = transforms.ToPILImage()(probabilities[0].squeeze()).resize(rgb.size)

    rgb.putalpha(mask)
    return rgb, untouched
def save_image(image):
    """Save *image* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so it outlives this call (the
    path is handed on to the Gradio outputs); this function never removes it
    on success.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method,
            or ``None``.

    Returns:
        The temporary file's path, or ``None`` when *image* is ``None``.

    Raises:
        Whatever ``image.save`` raises. The partially written temporary file
        is removed before the exception propagates.
    """
    if image is None:
        return None
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        # Write through the handle we already hold instead of reopening the
        # path a second time while it is still open.
        with temp_file:
            image.save(temp_file, format="PNG")
    except BaseException:
        # delete=False means nothing else will clean this file up; remove it
        # best-effort, then re-raise the original error.
        try:
            os.unlink(temp_file.name)
        except OSError:
            pass
        raise
    return temp_file.name
def process_and_download(input_image):
    """Remove the background from *input_image* and store the results as PNGs.

    Returns:
        ``([cutout_path, original_path], cutout_path)``. In this file's
        Interface, the list feeds the image slider and the single path feeds
        the PNG download component. Returns ``(None, None)`` when ``fn``
        produced no cutout.
    """
    cutout, original = fn(input_image)
    if cutout is None:
        return None, None

    # Save the cutout first, then the original (same order as before).
    cutout_file = save_image(cutout)
    slider_pair = [cutout_file, save_image(original)]
    return slider_pair, cutout_file
# Example inputs, loaded directly as PIL image objects at import time.
example_image1 = Image.open("example_images/example1.png")
example_image2 = Image.open("example_images/example2.png")
example_image3 = Image.open("example_images/example3.png")

# Interface components. The user-facing labels are Korean; translations are
# given in the trailing comments.
image = gr.Image(label="이미지 업로드")  # "Upload image"
slider = ImageSlider(label="배경 제거 결과", type="filepath")  # "Background removal result"
png_output = gr.File(label="PNG 다운로드")  # "Download PNG"

# Assemble the Gradio interface: one image in; a (slider pair, PNG file) out,
# matching the two values returned by process_and_download.
demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=[example_image1, example_image2, example_image3],
    title="배경 제거",  # "Background removal"
    # "Upload an image and the BiRefNet model removes its background.
    #  You can download the result as a PNG file."
    description=(
        "이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. "
        "결과를 PNG 파일로 다운로드할 수 있습니다."
    ),
)

if __name__ == "__main__":
    demo.launch()