File size: 799 Bytes
2a3e831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from utils.download import attempt_download_from_hub
import segmentation_models_pytorch as smp
from utils.dataloader import *
import torch


def _resolve_device(device=None):
    """Return *device* as a ``torch.device``; default to CUDA if available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(model_path, device):
    """Unpickle the full model object stored at *model_path* onto *device*.

    SECURITY NOTE: this performs full pickle deserialization, which can execute
    arbitrary code. Only load checkpoints from a trusted source.
    """
    import inspect

    load_kwargs = {"map_location": device}
    # The checkpoint holds a whole model object (``.predict`` is called on it),
    # not a bare state_dict, so restricted weights-only loading cannot read it.
    # torch >= 2.6 defaults weights_only=True; older torch (< 1.13) does not
    # accept the keyword at all, so only pass it when it is supported.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(model_path, **load_kwargs)


def unet_prediction(input_path, model_path, output_path="output.png", device=None):
    """Run the segmentation model on one input and save the binary mask as an image.

    Args:
        input_path: Passed to ``Dataset``; presumably the path of a single
            input image (TODO confirm against ``utils.dataloader.Dataset``).
        model_path: Passed to ``attempt_download_from_hub``; presumably a local
            path or a hub identifier resolving to a pickled model checkpoint.
        output_path: Where the predicted mask is written. Defaults to
            ``"output.png"`` in the current working directory.
        device: Torch device for inference. ``None`` selects ``"cuda"`` when
            available and falls back to ``"cpu"`` otherwise.

    Returns:
        The path (as ``str``) of the written mask image.

    Raises:
        OSError: If OpenCV fails to write the mask image.
    """
    import cv2  # explicit, instead of relying on the wildcard dataloader import

    device = _resolve_device(device)
    model = _load_model(attempt_download_from_hub(model_path), device)
    # Input normalization must match the encoder the model was trained with.
    preprocessing_fn = smp.encoders.get_preprocessing_fn('efficientnet-b6', 'imagenet')

    test_dataset = Dataset(
        input_path,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
    )
    image = test_dataset.get()

    # Add a leading batch dimension of size 1 before inference.
    x_tensor = torch.from_numpy(image).to(device).unsqueeze(0)
    pr_mask = model.predict(x_tensor)

    # round() binarizes at 0.5 (assumes the model outputs probabilities in
    # [0, 1] — TODO confirm); scale to {0, 255}. Clip before the uint8 cast so
    # out-of-range values saturate, as OpenCV's implicit conversion did.
    mask = (pr_mask.squeeze().cpu().numpy().round() * 255).clip(0, 255).astype("uint8")

    # Save the predicted mask; imwrite signals failure via its return value.
    output_path = str(output_path)
    if not cv2.imwrite(output_path, mask):
        raise OSError(f"cv2.imwrite failed to write predicted mask to {output_path!r}")
    return output_path