|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoImageProcessor |
|
|
|
from models.mobilevit import MobileVIT |
|
|
|
|
|
# Hub identifier passed to AutoImageProcessor.from_pretrained in
# single_prediction.
MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"

# Choose the compute backend once, at import time. Apple MPS is probed first,
# then CUDA; CPU is the fallback when neither is available.
if torch.backends.mps.is_available():
    ACCELERATOR = "mps"
elif torch.cuda.is_available():
    ACCELERATOR = "gpu"
else:
    ACCELERATOR = "cpu"

# Translate the accelerator label into the torch device that tensors and the
# model are moved to.
DEVICE = torch.device(
    {"mps": "mps:0", "gpu": "cuda", "cpu": "cpu"}[ACCELERATOR]
)
|
|
|
|
|
def single_prediction(image: "Image.Image") -> "Image.Image":
    """Run semantic segmentation on one image and return the mask as an image.

    The image is converted to RGB, preprocessed with the image processor
    loaded from ``MODEL_CHECKPOINT``, passed through ``MobileVIT().model`` on
    ``DEVICE``, and the post-processed class map is scaled by 255.

    Args:
        image: PIL image to segment; any mode that ``convert("RGB")`` accepts.

    Returns:
        A PIL image built from the uint8 class-index map multiplied by 255,
        resized by the post-processor to the input's (height, width).

    Note:
        The model wrapper and the image processor are instantiated on every
        call; callers in a hot path may want to cache them.
    """
    mobilevit_model = MobileVIT()
    mobilevit_model.to(DEVICE)
    # Inference only: make sure dropout / batch-norm layers behave
    # deterministically regardless of how the wrapper was constructed.
    mobilevit_model.eval()

    image_processor = AutoImageProcessor.from_pretrained(
        MODEL_CHECKPOINT, do_reduce_labels=False
    )

    np_image = np.asarray(image.convert("RGB"), dtype=np.uint8)
    height, width = np_image.shape[:2]

    # Keep the returned batch explicitly instead of relying on ``.to``
    # mutating the object in place.
    inputs = image_processor(images=np_image, return_tensors="pt").to(DEVICE)

    # No gradients are needed for prediction; skipping autograd saves memory
    # and removes the need to strip the graph with ``.data`` afterwards.
    with torch.no_grad():
        outputs = mobilevit_model.model(pixel_values=inputs["pixel_values"])
        segmentation_maps = image_processor.post_process_semantic_segmentation(
            outputs=outputs, target_sizes=[(height, width)]
        )

    # NOTE(review): scaling by 255 assumes a binary (0/1) label map; larger
    # class indices would wrap around in uint8 -- confirm against the model.
    mask = segmentation_maps[0].cpu().numpy().astype(np.uint8) * 255
    return Image.fromarray(mask)
|
|