Created an inference file, moved torch.no_grad into model.py, removed the timm user warning, and used a different photo from the demo as the default inference image.
- __pycache__/model.cpython-39.pyc +0 -0
- inference.py +17 -0
- model.py +2 -1
__pycache__/model.cpython-39.pyc ADDED
Binary file (29.3 kB).
inference.py ADDED
@@ -0,0 +1,17 @@
+import model
+from PIL import Image
+import torch
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+file = "./image.png" # input image
+
+model = model.BEN_Base().to(device).eval() #init pipeline
+
+model.loadcheckpoints("./BEN_Base.pth")
+image = Image.open(file)
+mask, foreground = model.inference(image)
+
+mask.save("./mask.png")
+foreground.save("./foreground.png")
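To process more than one image, the same pipeline can simply be looped over a folder. The sketch below is a hypothetical variant of inference.py, not part of this commit: the ./images, ./masks, and ./foregrounds paths are illustrative, and it relies only on what the script above already shows, namely that inference takes a single PIL image and returns a saveable (mask, foreground) pair.

import os
import torch
from PIL import Image

import model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = model.BEN_Base().to(device).eval()     # same pipeline as inference.py
net.loadcheckpoints("./BEN_Base.pth")

os.makedirs("./masks", exist_ok=True)        # hypothetical output folders
os.makedirs("./foregrounds", exist_ok=True)

for name in sorted(os.listdir("./images")):  # hypothetical input folder
    image = Image.open(os.path.join("./images", name))
    mask, foreground = net.inference(image)  # one image at a time, as above
    stem = os.path.splitext(name)[0]
    mask.save(os.path.join("./masks", stem + ".png"))
    foreground.save(os.path.join("./foregrounds", stem + ".png"))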
model.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from einops import rearrange
 from PIL import Image, ImageFilter, ImageOps
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_
 from torchvision import transforms
 
 class Mlp(nn.Module):
@@ -887,6 +887,7 @@ class BEN_Base(nn.Module):
 
         return final_output.sigmoid()
 
+    @torch.no_grad()
     def inference(self,image):
         image, h, w,original_image = rgb_loader_refiner(image)
 
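With @torch.no_grad() now attached to inference itself, callers such as inference.py no longer need to wrap the call in a with torch.no_grad(): block; autograd is switched off inside the method and re-enabled when it returns. A minimal toy sketch (not from this repository) illustrating the decorator's effect:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    @torch.no_grad()              # autograd is disabled for the whole method
    def inference(self, x):
        return self.linear(x)

toy = Toy().eval()
out = toy.inference(torch.randn(2, 4))
print(out.requires_grad)          # False: no graph was built, saving memory

Keeping the decorator on the model method means the no-grad behaviour travels with the model rather than depending on every caller remembering it.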