Spaces:
Runtime error
Runtime error
File size: 2,505 Bytes
3030eb7 eb7a233 3030eb7 acf5692 eb7a233 3030eb7 7d3d39b 3030eb7 7d3d39b 3030eb7 eb7a233 3030eb7 7d3d39b eb7a233 3030eb7 7d3d39b 8cc2603 eb7a233 7d3d39b eb7a233 3030eb7 eb7a233 3030eb7 7d3d39b eb7a233 9ea4ecb 7d3d39b eb7a233 7d3d39b eb7a233 3030eb7 7d3d39b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
import tempfile
# Let float32 matmuls use the "high" precision setting, which permits faster
# reduced-precision internal math (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model from the Hugging Face Hub. trust_remote_code=True
# lets transformers run the model code shipped inside the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing for the model input: resize to 1024x1024, convert to a float
# tensor, then normalize each channel with the standard ImageNet mean/std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@spaces.GPU
def fn(image):
    """Cut the background out of *image* using the BiRefNet model.

    Args:
        image: Any input accepted by ``load_img``; ``None`` means no image
            was supplied.

    Returns:
        A ``(cutout, original)`` pair of PIL images. ``cutout`` is the input
        converted to RGB with the model's predicted mask (resized back to the
        input size) attached as its alpha channel. ``original`` is an
        untouched RGB copy of the input. Both are ``None`` when *image* is
        ``None``.
    """
    if image is None:
        return None, None

    rgb = load_img(image, output_type="pil").convert("RGB")
    width_height = rgb.size
    # Keep a pristine copy: putalpha() below modifies ``rgb`` in place.
    original = rgb.copy()

    batch = transform_image(rgb).unsqueeze(0).to("cuda")

    # Prediction: use the model's last output and squash it into (0, 1).
    with torch.no_grad():
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    mask = transforms.ToPILImage()(probabilities[0].squeeze()).resize(width_height)
    rgb.putalpha(mask)
    return rgb, original
def save_image(image):
    """Write *image* to a new temporary PNG file and return the file's path.

    Args:
        image: A PIL image (anything with a PIL-style
            ``save(fp, format=...)`` method), or ``None``.

    Returns:
        Path of the written ``.png`` file, or ``None`` if *image* is ``None``.

    The file is created with ``delete=False`` so the path stays valid after
    this function returns (callers hand it to Gradio outputs). Nothing here
    removes it afterwards.
    """
    if image is None:
        return None
    # Write through the handle that is already open rather than reopening the
    # file by name while it is still open, which does not work on every
    # platform (notably Windows). A file-object target requires an explicit
    # format, which is given.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        image.save(temp_file, format="PNG")
    return temp_file.name
def process_and_download(input_image):
    """Remove the background from *input_image* and save the results as PNGs.

    Args:
        input_image: Value from the Gradio image input; ``None`` when the
            user submits without uploading an image.

    Returns:
        ``([result_path, original_path], result_path)``. The list feeds the
        image-slider output and the single path feeds the PNG download
        output. Returns ``(None, None)`` when there is nothing to process.
    """
    # Return early so an empty submission never enters the @spaces.GPU-
    # decorated ``fn``, which would request a GPU only to return nothing.
    if input_image is None:
        return None, None

    result, original = fn(input_image)
    if result is None:
        return None, None

    result_path = save_image(result)
    original_path = save_image(original)
    return [result_path, original_path], result_path
# Example images are loaded directly as PIL objects and passed to the
# interface as-is.
example_image1 = Image.open("example_images/example1.png")
example_image2 = Image.open("example_images/example2.png")
example_image3 = Image.open("example_images/example3.png")

# Interface components. The Korean UI strings are restored from mis-decoded
# bytes; in the garbled form two of them were split across lines (a
# SyntaxError).
image = gr.Image(label="이미지 업로드")
slider = ImageSlider(label="배경 제거 결과", type="filepath")
png_output = gr.File(label="PNG 다운로드")

# Assemble the Gradio interface: one image input; two outputs (the slider
# and the PNG file) matching the two values process_and_download returns.
demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=[example_image1, example_image2, example_image3],
    title="배경 제거",
    description=(
        "이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. "
        "결과를 PNG 파일로 다운로드할 수 있습니다."
    ),
)

if __name__ == "__main__":
    demo.launch()
|