# u2net-saliency / u2net_inference.py
# Commit 85f8cd2 by reeteshmukul: "model organization for u2net" (2.46 kB)
import os
from typing import Union
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from .data_loader import RescaleT
from .data_loader import ToTensor
from .data_loader import ToTensorLab
from .data_loader import SalObjDataset
from .u2net import U2NET # full size version 173.6 MB
from .u2net import U2NETP # small version u2net 4.7 MB
# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalize a prediction tensor into the range [0, 1].

    Args:
        d: Tensor of predicted values.

    Returns:
        A tensor shaped like ``d`` holding ``(d - min) / (max - min)``.
        When every element of ``d`` is equal (``max == min``) the result is
        all zeros rather than NaN.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    # A constant map has zero range; dividing by it would fill the output
    # with NaN. Substituting 1 keeps the division valid and yields zeros,
    # because the numerator (d - mi) is already zero everywhere.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (d - mi) / safe_span
def save_output(image_name, pred, d_dir):
    """Save a prediction as a PNG resized to match its source image.

    Args:
        image_name: Path of the source image. It is read to obtain the target
            height/width, and its base name (minus the final extension) names
            the output file.
        pred: Prediction tensor. It is squeezed, moved to the CPU, scaled by
            255 and converted to an RGB image.
        d_dir: Output location, prepended verbatim to the file name, so it
            must already end with a path separator if it is a directory.
    """
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # The source image is loaded only for its dimensions; PIL's resize takes
    # (width, height), hence shape[1] before shape[0].
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # splitext drops only the last extension ("a.b.jpg" -> "a.b") and, unlike
    # manual splitting on ".", also copes with names that have no extension.
    stem, _ = os.path.splitext(os.path.basename(image_name))
    imo.save(d_dir + stem + '.png')
def get_u2net_model(model_pth="/Users/reeteshmukul/me/model/saliency/u2net.pth"):
    """Build a U2NET, load its weights, and return it ready for inference.

    Args:
        model_pth: Path to the checkpoint holding the state dict. Defaults to
            the location the original code had hard-coded.

    Returns:
        The ``U2NET(3, 1)`` instance with weights loaded, placed on
        ``cuda:0`` when CUDA is available (otherwise the CPU), in eval mode.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = U2NET(3, 1)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    net.load_state_dict(torch.load(model_pth, map_location=device))
    # load_state_dict copies values into the existing parameters and does not
    # relocate the module, so move it explicitly to the device inputs will be
    # sent to.
    net.to(device)
    net.eval()
    return net
def get_saliency_mask(model, image_or_image_path: Union[str, np.ndarray]):
    """Run the model on one image and return its saliency map as a PIL image.

    Args:
        model: Network whose forward pass returns seven outputs; only the
            first one is used.
        image_or_image_path: Either a path to an image file (read with
            ``skimage.io.imread``) or an already-loaded image array.

    Returns:
        An RGB ``PIL.Image`` built from channel 0 of the first model output,
        min-max normalized and scaled by 255. It is not resized back to the
        input image's dimensions.
    """
    if isinstance(image_or_image_path, str):
        image = io.imread(image_or_image_path)
    else:
        image = image_or_image_path

    # Greyscale arrays arrive as (H, W); give them an explicit channel axis.
    # NOTE(review): assumes RescaleT/ToTensorLab expect H x W x C input —
    # confirm against data_loader.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]

    # Follow the model's own device so inputs and weights always agree; fall
    # back to CUDA availability only for a parameter-less model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Named `preprocess` so it does not shadow the imported skimage.transform.
    preprocess = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    sample = preprocess({
        'imidx': np.array([0]),
        'image': image,
        # Placeholder label: the transforms operate on the whole sample dict.
        'label': np.zeros(image.shape[:2] + (1,)),
    })
    input_test = sample["image"].unsqueeze(0).to(device=device, dtype=torch.float32)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = model(input_test)

    pred = normPRED(d1[:, 0, :, :]).squeeze()
    predict_np = pred.cpu().numpy()
    return Image.fromarray(predict_np * 255).convert("RGB")