# Spaces:
# Runtime error
# Runtime error
import streamlit as st | |
from transformers import SamModel, SamProcessor, pipeline | |
from PIL import Image, ImageOps | |
import numpy as np | |
import torch | |
# Point-prompt divisors (x, y): each pair places a SAM point prompt at
# (image_width // x, image_height // y), i.e. at fraction (1/x, 1/y) of the image.
XS_YS = [
    (2.0, 2.0),  # fraction (0.5, 0.5)
    (2.5, 2.5),  # fraction (0.4, 0.4)
]

# Display width, in pixels, passed to every st.image call.
WIDTH = 600
# Load models | |
def load_models(): | |
model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77") | |
processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77") | |
od_pipe = pipeline("object-detection", "facebook/detr-resnet-50") | |
return model, processor, od_pipe | |
def process_image(image, model, processor, bounding_box=None, input_point=None):
    """Run SAM on ``image`` with a single box prompt or a single point prompt.

    Args:
        image: PIL image; converted to RGB before it is handed to ``processor``.
        model: the SAM model (as returned by ``load_models``).
        processor: the matching ``SamProcessor``.
        bounding_box: box prompt in pixel coordinates of ``image``, either flat
            ``[x_min, y_min, x_max, y_max]`` or as corner pairs
            ``[[x_min, y_min], [x_max, y_max]]``. Used in preference to
            ``input_point`` when both are given.
        input_point: point prompt ``[x, y]`` in pixel coordinates of ``image``.

    Returns:
        Element 0 of ``processor.image_processor.post_process_masks(...)``,
        i.e. the post-processed masks for this single image. Returns ``None``
        if anything fails; in that case the error is reported via ``st.error``.
    """
    try:
        image_array = np.array(image.convert('RGB'))

        # Explicit ``is not None`` checks: a NumPy-array prompt has no
        # unambiguous truth value, so ``if bounding_box:`` would raise on it.
        if bounding_box is not None:
            # Accept flat or corner-pair boxes and normalise to 4 floats.
            box = np.asarray(bounding_box, dtype=np.float32).reshape(-1).tolist()
            if len(box) != 4:
                raise ValueError("bounding_box must contain exactly 4 coordinates")
            # Nesting as in the SAM docs: [image][box][x_min, y_min, x_max, y_max].
            inputs = processor(images=image_array, input_boxes=[[box]], return_tensors="pt")
        elif input_point is not None:
            # Nesting as in the SAM docs example: [[[x, y]]] for one point.
            inputs = processor(images=image_array, input_points=[[input_point]], return_tensors="pt")
        else:
            raise ValueError("Either bounding_box or input_point must be provided")

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)

        predicted_masks = processor.image_processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"],
        )
        return predicted_masks[0]
    except Exception as e:
        # UI boundary: report the failure to the user instead of aborting the
        # whole script run, so remaining prompts can still be processed.
        st.error(f"Error processing image: {str(e)}")
        return None
def display_masked_images(raw_image, predicted_mask, caption_prefix):
    """Show one image per predicted mask: the masked region kept, the rest white.

    Args:
        raw_image: PIL image the masks were computed for (same size as each mask).
        predicted_mask: masks as returned by ``process_image``. ``predicted_mask[0]``
            is iterated and each element is treated as one 2-D mask
            (presumably shape ``(num_prompts, num_masks, H, W)``; TODO confirm).
        caption_prefix: caption text; the 1-based mask index is appended.
    """
    # Image.composite requires both images to have the same mode, and uploads
    # may be RGBA / L / P, so composite against an explicit RGB copy.
    rgb_image = raw_image.convert('RGB')
    background = Image.new('RGB', rgb_image.size, (255, 255, 255))

    # Iterate over however many masks the model returned rather than assuming 3.
    for i, mask in enumerate(predicted_mask[0]):
        # Boolean mask -> 0/255 uint8, which fromarray turns into an 'L' image.
        mask_array = np.asarray(mask).astype(np.uint8) * 255
        mask_image = Image.fromarray(mask_array)
        final_image = Image.composite(rgb_image, background, mask_image)
        st.image(final_image, caption=f"{caption_prefix} {i+1}", width=WIDTH)
def _render_intro():
    """Render the app's introduction, how-to and model credits."""
    st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further Segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.
Please note that processing takes some time. We appreciate your patience as the models do their work!
""")
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")


def _segment_detections(raw_image, model, processor, od_pipe):
    """Detect objects with ``od_pipe`` and show SAM masks for every detected box."""
    detections = od_pipe(raw_image)
    if not detections:
        st.info("No objects were detected in this image.")
    for detection in detections:
        box = detection['box']
        # SAM box prompt: flat [x_min, y_min, x_max, y_max] in pixel coordinates.
        bounding_box = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        st.subheader(f'bounding box : {detection["label"]}')
        predicted_mask = process_image(raw_image, model, processor, bounding_box=bounding_box)
        if predicted_mask is not None:
            display_masked_images(raw_image, predicted_mask, "Masked Image")


def _segment_points(raw_image, model, processor):
    """Show SAM masks for point prompts placed according to ``XS_YS``."""
    width, height = raw_image.size
    for x, y in XS_YS:
        # x, y are divisors: the prompt sits at (width // x, height // y),
        # i.e. at fraction (1/x, 1/y) of the image.
        point_x, point_y = width // x, height // y
        st.subheader(f"Input points : ({1/x},{1/y})")
        predicted_mask = process_image(raw_image, model, processor, input_point=[point_x, point_y])
        if predicted_mask is not None:
            display_masked_images(raw_image, predicted_mask, "Masked Image")


def main():
    """Render the Streamlit app: intro, upload widget and segmentation results.

    For an uploaded image, objects are detected with the DETR pipeline, and
    each detected box is used as a SAM box prompt. Then each entry of
    ``XS_YS`` is used as a SAM point prompt.
    """
    st.title("Image Segmentation with Object Detection")
    _render_intro()

    model, processor, od_pipe = load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # The transformers pipeline applies EXIF orientation and converts to
        # RGB before detecting. Do the same here, so the detected boxes refer
        # to the same pixels SAM sees and the masks are composited onto an
        # RGB image.
        raw_image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError as e:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded image: {e}")
        return

    st.subheader("Uploaded Image")
    st.image(raw_image, caption="Uploaded Image", width=WIDTH)

    with st.spinner('Processing image...'):
        _segment_detections(raw_image, model, processor, od_pipe)
        _segment_points(raw_image, model, processor)
# Script entry point: run the app when this file is executed directly.
if __name__ == "__main__":
    main()