not-lain commited on
Commit
0e6c023
·
1 Parent(s): 003d203
Files changed (2) hide show
  1. app.py +41 -5
  2. requirements +13 -1
app.py CHANGED
@@ -1,12 +1,48 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
 
7
  @spaces.GPU
8
- def greet(n):
9
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
12
- demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from loadimg import load_img
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+
8
+ torch.set_float32_matmul_precision(["high", "highest"][0])
9
+
10
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
11
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
12
+ )
13
+ birefnet.to("cuda")
14
+
15
+ transform_image = transforms.Compose(
16
+ [
17
+ transforms.Resize((1024, 1024)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
20
+ ]
21
+ )
22
 
 
23
 
24
  @spaces.GPU
25
+ def rmbg(image):
26
+ image = load_img().convert("RGB")
27
+ image_size = image.size
28
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
29
+ # Prediction
30
+ with torch.no_grad():
31
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
32
+ pred = preds[0].squeeze()
33
+ pred_pil = transforms.ToPILImage()(pred)
34
+ mask = pred_pil.resize(image_size)
35
+ image.putalpha(mask)
36
+ return image
37
+
38
+
39
+ rmbg_tab = gr.Interface(fn=rmbg, inputs=["text"], outputs=["image"], api_name="rmbg")
40
+
41
+ demo = gr.TabbedInterface(
42
+ [rmbg_tab],
43
+ ["remove background"],
44
+ title="Background Removal",
45
+ )
46
+
47
 
48
+ demo.launch()
 
requirements CHANGED
@@ -1,2 +1,14 @@
1
  spaces
2
- torch
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  spaces
2
+ torch
3
+ torchvision
4
+ opencv-python
5
+ tqdm
6
+ timm
7
+ prettytable
8
+ scipy
9
+ scikit-image
10
+ kornia
11
+ gradio_imageslider
12
+ transformers
13
+ huggingface_hub
14
+ loadimg