Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch,torchvision | |
import detectron2 | |
import pickle | |
from detectron2.utils.logger import setup_logger | |
logger = setup_logger() | |
from detectron2.engine import DefaultPredictor | |
from detectron2.projects import point_rend | |
from viz_app import plot_single_image # Correct import | |
from PIL import Image | |
import numpy as np | |
from detectron2.data import MetadataCatalog, DatasetCatalog | |
from detectron2.utils.visualizer import Visualizer | |
from collections import Counter | |
# Sidebar label of each selectable model -> directory that holds its
# ``cfg.pickle`` (config) and ``model_best.pth`` (weights).
output_directories = {"4 Classes": "output/pointrend_4_cls_6_nest_8k/"}
def load_configuration(output_dir):
    """Load the pickled detectron2 config stored as ``cfg.pickle`` in *output_dir*.

    Args:
        output_dir: Directory (``str`` or path-like) containing ``cfg.pickle``.
            A trailing path separator is optional.

    Returns:
        The unpickled configuration object.
    """
    import os

    # SECURITY: pickle.load can execute arbitrary code. Only point this at
    # trusted, locally produced files (e.g. the model directories the app
    # ships with), never at user-supplied data.
    with open(os.path.join(output_dir, "cfg.pickle"), "rb") as f:
        return pickle.load(f)
def run_instance_segmentation(im, predictor, num_classes):
    """Run the predictor on one image and summarise the predicted masks.

    Args:
        im: Image array forwarded unchanged to ``plot_single_image``.
        predictor: Predictor forwarded unchanged to ``plot_single_image``.
        num_classes: Number of classes of the selected model; only 2 and 4
            are supported.

    Returns:
        ``(total_nest_area, total_inorganic_area, mask_counts)``:
        the two areas are ``np.sum`` over the masks of every instance of the
        respective class group (a pixel count if the masks are boolean;
        overlapping instances are counted once per instance), and
        ``mask_counts`` is a ``Counter`` mapping each predicted class id to
        its number of instances.

    Raises:
        ValueError: if ``num_classes`` is neither 2 nor 4.
    """
    # Class-id layout per model: (Nest class id, Inorganic Material class ids).
    class_layouts = {
        2: (0, (1,)),       # 0 = Nest, 1 = Inorganic Material
        4: (3, (0, 1, 2)),  # 0 = Plastic, 1 = Fishing_Net, 2 = Rope/Cloth, 3 = Nest
    }
    # Validate before running inference; previously an unsupported value
    # surfaced as an UnboundLocalError after the (expensive) prediction.
    if num_classes not in class_layouts:
        raise ValueError(
            f"Unsupported num_classes={num_classes!r}; expected one of {sorted(class_layouts)}"
        )
    nest_id, inorganic_ids = class_layouts[num_classes]

    outputs = plot_single_image(im, predictor, num_classes)
    instances = outputs['instances']
    masks = instances.pred_masks.cpu().numpy()
    classes = instances.pred_classes.cpu().numpy()

    # Total area covered by Nest masks.
    total_nest_area = np.sum(masks[classes == nest_id])
    # Total area covered by all inorganic-material masks combined.
    total_inorganic_area = np.sum(masks[np.isin(classes, inorganic_ids)])
    # Number of predicted masks per class id.
    mask_counts = Counter(classes)
    return total_nest_area, total_inorganic_area, mask_counts
def click_instance_segmentation(image, model_selection, predictor):
    """Show a "Run Instance Segmentation" button and report results when clicked.

    Args:
        image: PIL image to segment.
        model_selection: Sidebar label of the chosen model. "2 Classes"
            selects the 2-class layout; any other value selects the 4-class one.
        predictor: Predictor forwarded to ``run_instance_segmentation``.
    """
    if not st.button("Run Instance Segmentation"):
        return

    # Force exactly 3 channels first: PNG uploads may be RGBA, palette or
    # grayscale, which broke the channel reversal below. The [:, :, ::-1]
    # flips RGB -> BGR (presumably the predictor's expected input format,
    # detectron2's default; confirm against cfg.INPUT.FORMAT).
    im = np.array(image.convert("RGB"))[:, :, ::-1]

    num_classes = 2 if model_selection == "2 Classes" else 4
    total_nest_area, total_inorganic_area, mask_counts = run_instance_segmentation(im, predictor, num_classes)

    # The ratio is undefined when no Nest pixels were detected; say so instead
    # of printing "inf%" / "nan%".
    if total_nest_area > 0:
        percentage_inorganic_in_nest = (total_inorganic_area / total_nest_area) * 100
        st.write(f"Percentage of Inorganic Material in Nest: {percentage_inorganic_in_nest:.2f}%")
    else:
        st.write("No Nest detected, so the percentage of Inorganic Material cannot be computed.")

    # Display names indexed by predicted class id; the 2-class model uses a
    # different layout (0 = Nest, 1 = Inorganic Material) than the 4-class one.
    if num_classes == 2:
        category_list = ['Nest', 'Inorganic Material']
    else:
        category_list = ['Plastic', 'Fishing_Net', 'Rope/Cloth', "Nest"]

    # Show the number of masks per category, in class-id order.
    st.write("Mask Counts by Category:")
    for category, count in sorted(mask_counts.items()):
        st.write(f"Category {category_list[category]}: {count} masks")
@st.cache_resource
def _load_predictor(output_dir):
    """Build the CPU predictor for *output_dir* once and reuse it across reruns."""
    import os

    cfg = load_configuration(output_dir)
    cfg.MODEL.WEIGHTS = os.path.join(output_dir, "model_best.pth")
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.65
    return DefaultPredictor(cfg)


def app():
    """Render the page: model picker, image input (upload or default) and results."""
    # Sidebar selector for the model; its label keys output_directories.
    model_selection = st.sidebar.radio("Select Model", list(output_directories.keys()))
    # Streamlit re-executes the script on every interaction; caching avoids
    # unpickling the config and rebuilding the model each time.
    predictor = _load_predictor(output_directories[model_selection])

    default_image_path = 'image/DJI_0228_frame_0022_object_0.jpg'
    st.header('Please upload an image')
    file = st.file_uploader('', type=['png', 'jpg', 'jpeg'])
    if file:
        image, caption = Image.open(file), "Uploaded Image"
    else:
        st.write("No image uploaded. Using default example image.")
        image, caption = Image.open(default_image_path), "Default Example Image"
    st.image(image, caption=caption, use_column_width=True)
    click_instance_segmentation(image, model_selection, predictor)


app()