petergpt commited on
Commit
623e1bf
·
verified ·
1 Parent(s): 536044a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForImageSegmentation
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import gradio as gr
6
+
7
+ # Load the model from Hugging Face
8
+ birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ birefnet.to(device)
11
+ birefnet.eval()
12
+
13
+ # Define the transform to preprocess the input image
14
+ image_size = (1024, 1024)
15
+ transform_image = transforms.Compose([
16
+ transforms.Resize(image_size),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ def extract_object(image):
22
+ input_images = transform_image(image).unsqueeze(0).to(device)
23
+ with torch.no_grad():
24
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
25
+ pred = preds[0].squeeze()
26
+ pred_pil = transforms.ToPILImage()(pred)
27
+ mask = pred_pil.resize(image.size)
28
+ image_with_alpha = image.convert("RGBA")
29
+ image_with_alpha.putalpha(mask)
30
+ return image_with_alpha
31
+
32
+ iface = gr.Interface(
33
+ fn=extract_object,
34
+ inputs=gr.Image(type="pil", label="Upload Image"),
35
+ outputs=gr.Image(type="pil", label="Segmented Image"),
36
+ title="BiRefNet Background Removal",
37
+ description="Upload an image and get the foreground object extracted."
38
+ )
39
+
40
+ if __name__ == "__main__":
41
+ iface.launch()