|
import os |
|
from typing import Union |
|
from skimage import io, transform |
|
import torch |
|
import torchvision |
|
from torch.autograd import Variable |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
|
|
|
|
import numpy as np |
|
from PIL import Image |
|
import glob |
|
|
|
from .data_loader import RescaleT |
|
from .data_loader import ToTensor |
|
from .data_loader import ToTensorLab |
|
from .data_loader import SalObjDataset |
|
|
|
from .u2net import U2NET |
|
from .u2net import U2NETP |
|
|
|
|
|
|
|
def normPRED(d):
    """Min-max normalise tensor ``d`` into the range [0, 1].

    Args:
        d: Tensor of predictions (any shape; must be non-empty).

    Returns:
        ``(d - min(d)) / (max(d) - min(d))``. When every element of ``d`` is
        equal, the denominator is zero and the plain formula would yield an
        all-NaN tensor (0 / 0); a zero tensor of the same shape is returned
        instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    value_range = ma - mi

    # Constant input: every (d - mi) is 0, so the normalised map is all zeros.
    if value_range == 0:
        return torch.zeros_like(d)

    return (d - mi) / value_range
|
|
|
def save_output(image_name, pred, d_dir):
    """Save a saliency prediction as a PNG resized to match its source image.

    Args:
        image_name: Path to the source image. The image is read only to get its
            width and height; its base name, minus the final extension, names
            the output file (``"dir/a.b.jpg"`` -> ``"a.b.png"``).
        pred: Prediction tensor. It is squeezed before conversion, so it should
            reduce to a 2-D map. Presumably already normalised to [0, 1]
            (e.g. by ``normPRED``) -- TODO confirm for all callers.
        d_dir: Output directory. It is combined with ``os.path.join``, so a
            trailing path separator is optional.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()

    # Scale to 0-255 and convert to a 3-channel PIL image.
    im = Image.fromarray(predict_np * 255).convert('RGB')

    # Resize the mask back to the source image's (width, height).
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # Strip only the final extension. Dots inside the stem are preserved, and
    # a name with no extension is handled (the old split/join loop raised
    # IndexError on it).
    stem = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, stem + '.png'))
|
|
|
|
|
def get_u2net_model(model_pth="/Users/reeteshmukul/me/model/saliency/u2net.pth"):
    """Build a ``U2NET(3, 1)`` network, load its weights and return it ready for inference.

    Args:
        model_pth: Path to the saved state dict. The default is the previously
            hard-coded location, so existing zero-argument calls behave the
            same; pass an explicit path on any other machine.

    Returns:
        The network in eval mode, placed on CUDA device 0 when CUDA is
        available, otherwise on the CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = U2NET(3, 1)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(model_pth, map_location=device))

    # load_state_dict copies values into the module's existing parameters,
    # which were created on the CPU, so map_location alone does not move the
    # model. Move it explicitly so the weights actually live on `device`.
    net.to(device)
    net.eval()

    return net
|
|
|
|
|
def _model_device(model):
    """Return the device holding ``model``'s first parameter.

    Falls back to CUDA device 0 when available, otherwise the CPU, when
    ``model`` exposes no parameters.
    """
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_saliency_mask(model, image_or_image_path: Union[str, np.ndarray]):
    """Run the saliency model on one image and return the normalised mask.

    Args:
        model: Callable network, e.g. from ``get_u2net_model``, that returns
            seven output maps; only the first one is used.
        image_or_image_path: Path to an image file, or an already-loaded image
            array (H x W x C, or H x W for grayscale).

    Returns:
        An RGB ``PIL.Image`` of the min-max-normalised first output map. It is
        at the network input resolution produced by ``RescaleT(320)``, not the
        original image size (presumably 320 x 320 -- confirm against RescaleT).
    """
    # Put the input on whichever device the model's weights are on. Choosing
    # "cuda:0" independently of the model caused a device-mismatch error
    # whenever the model was on the CPU.
    device = _model_device(model)

    if isinstance(image_or_image_path, str):
        image = io.imread(image_or_image_path)
    else:
        image = image_or_image_path

    # Grayscale input: add a channel axis. The dummy label below is built from
    # the leading (H, W) dims, which a 2-D array does not provide via shape[:-1].
    if image.ndim == 2:
        image = image[:, :, np.newaxis]

    # Named `preprocess` so it does not shadow the imported skimage `transform`.
    preprocess = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    sample = preprocess({
        'imidx': np.array([0]),
        'image': image,
        # The transforms expect a label entry; inference has none, so pass zeros.
        'label': np.zeros(image.shape[:2] + (1,)),
    })

    input_test = sample["image"].unsqueeze(0).type(torch.FloatTensor).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        d1, *_side_outputs = model(input_test)

    pred = normPRED(d1[:, 0, :, :])
    predict_np = pred.squeeze().cpu().numpy()

    return Image.fromarray(predict_np * 255).convert("RGB")