# Spaces: Runtime error
import os
import tempfile

# Third-party imports keep their original relative order: on ZeroGPU Spaces
# the `spaces` package is sensitive to being imported before torch code runs.
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# Allow reduced-precision (e.g. TF32) float32 matmuls ("high") instead of
# strict float32 ("highest") in exchange for speed.
torch.set_float32_matmul_precision("high")

# trust_remote_code=True lets transformers execute the modeling code that is
# hosted in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied before inference: resize to a fixed 1024x1024 input,
# convert to a float tensor, then normalize each channel with the usual
# ImageNet mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# On ZeroGPU hardware, any function that touches CUDA must be wrapped with
# @spaces.GPU so a GPU is attached for the duration of the call; `spaces` was
# already imported for this purpose but never applied.
@spaces.GPU
def fn(image):
    """Cut out the foreground of *image* with BiRefNet.

    Args:
        image: Any value ``loadimg.load_img`` accepts (e.g. the value coming
            from the ``gr.Image`` input), or ``None``.

    Returns:
        ``(cutout, original)``. ``cutout`` is the input as an RGBA PIL image
        whose alpha channel is the predicted foreground mask, resized to the
        input resolution. ``original`` is an untouched RGB copy of the input.
        Returns ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    origin = im.copy()  # keep an RGB copy; `im` gets an alpha channel below

    input_images = transform_image(im).unsqueeze(0).to("cuda")

    # Prediction: the last element of the model output holds the raw mask
    # scores; sigmoid maps them into [0, 1].
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage turns the float mask into an 8-bit grayscale image; scale it
    # back from the 1024x1024 model resolution to the input size.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    im.putalpha(mask)
    return im, origin
def save_image(image):
    """Write *image* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so it survives this call;
    nothing in this function removes it after a successful save.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method
            (a PIL image in this app), or ``None``.

    Returns:
        The temporary file's path, or ``None`` when *image* is ``None``.

    Raises:
        Whatever ``image.save`` raises; the partially written file is
        removed first.
    """
    if image is None:
        return None
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        # Write through the already-open handle rather than reopening the file
        # by name: reopening a still-open NamedTemporaryFile fails on Windows.
        with temp_file:
            image.save(temp_file, format="PNG")
    except Exception:
        # With delete=False nothing else would clean up a half-written file.
        os.unlink(temp_file.name)
        raise
    return temp_file.name
def process_and_download(input_image):
    """Run background removal and shape the results for the two output widgets.

    Returns:
        ``(slider_value, download_path)``. ``slider_value`` is the list
        ``[cutout_png_path, original_png_path]`` for the image slider, and
        ``download_path`` is the cutout PNG's path for the file widget.
        Returns ``(None, None)`` when ``fn`` produces no result.
    """
    cutout, source = fn(input_image)
    if cutout is not None:
        cutout_path = save_image(cutout)
        source_path = save_image(source)
        return [cutout_path, source_path], cutout_path
    return None, None
def _load_example(path):
    """Return an in-memory copy of the image at *path*, with the file closed.

    ``Image.open`` is lazy and keeps the file handle open until the pixels are
    read; ``copy()`` forces the read, and the ``with`` block then closes it.
    """
    with Image.open(path) as img:
        return img.copy()


# Load the example images directly as PIL objects.
example_image1 = _load_example("example_images/example1.png")
example_image2 = _load_example("example_images/example2.png")
example_image3 = _load_example("example_images/example3.png")

# Define the interface components (UI text is Korean).
image = gr.Image(label="이미지 업로드")  # "Upload image"
slider = ImageSlider(label="배경 제거 결과", type="filepath")  # "Background removal result"
png_output = gr.File(label="PNG 다운로드")  # "Download PNG"

# Assemble the Gradio interface.
demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=[example_image1, example_image2, example_image3],
    title="배경 제거",  # "Background removal"
    # "Upload an image and its background is removed with the BiRefNet model.
    #  You can download the result as a PNG file."
    description=(
        "이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. "
        "결과를 PNG 파일로 다운로드할 수 있습니다."
    ),
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()