"""Gradio app that removes image backgrounds with the BiRefNet segmentation model.

An uploaded image is run through BiRefNet, the predicted mask is applied as the
image's alpha channel, and the result is shown in an image slider alongside the
original and offered as a PNG download.
"""

import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
import tempfile

# "high" allows PyTorch to use faster reduced-precision float32 matmul
# kernels where the hardware supports them.
torch.set_float32_matmul_precision("high")

# Load BiRefNet once at import time and keep it on the GPU for all requests.
# trust_remote_code lets transformers execute the model code from the Hub repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing: resize to a fixed 1024x1024 input, then normalize with the
# standard ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Stateless tensor -> PIL converter for the predicted mask; build it once.
to_pil_image = transforms.ToPILImage()


@spaces.GPU
def fn(image):
    """Remove the background of *image*.

    Args:
        image: Input accepted by ``load_img`` (here, the value produced by the
            ``gr.Image`` input component), or ``None``.

    Returns:
        ``(result, original)`` where ``result`` is a PIL image with the
        predicted mask applied as its alpha channel and ``original`` is an RGB
        copy of the input. Returns ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    origin = im.copy()  # keep an untouched copy; putalpha below mutates im

    input_images = transform_image(im).unsqueeze(0).to("cuda")

    # Prediction: take the model's last output and squash it to [0, 1]
    # (presumably mask logits, given the sigmoid).
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Resize the mask back to the input resolution and use it as alpha.
    mask = to_pil_image(pred).resize(image_size)
    im.putalpha(mask)
    return im, origin


def save_image(image):
    """Write *image* to a new temporary PNG file and return the file path.

    The file is created with ``delete=False`` so it outlives this function and
    can be served by Gradio; it is not cleaned up here. Returns ``None`` when
    *image* is ``None``.
    """
    if image is None:
        return None
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        # Write through the open handle rather than reopening the file by name:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        image.save(temp_file, format="PNG")
    return temp_file.name


def process_and_download(input_image):
    """Gradio handler: remove the background and save both images as PNGs.

    Returns:
        ``([result_path, original_path], result_path)`` for the image slider
        and the download component, or ``(None, None)`` when there is no input.
    """
    result, original = fn(input_image)
    if result is None:
        return None, None

    result_path = save_image(result)
    original_path = save_image(original)
    return [result_path, original_path], result_path


# Load the example images directly as PIL objects.
example_image1 = Image.open("example_images/example1.png")
example_image2 = Image.open("example_images/example2.png")
example_image3 = Image.open("example_images/example3.png")

# Interface components.
image = gr.Image(label="이미지 업로드")
slider = ImageSlider(label="배경 제거 결과", type="filepath")
png_output = gr.File(label="PNG 다운로드")

# Assemble the Gradio interface.
demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=[example_image1, example_image2, example_image3],
    title="배경 제거",
    description="이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. 결과를 PNG 파일로 다운로드할 수 있습니다.",
)

if __name__ == "__main__":
    demo.launch()