# Author: beingcognitive
# Commit 9843137 (verified): "For Tech Campus class"
import streamlit as st
from transformers import SamModel, SamProcessor, pipeline
from PIL import Image, ImageOps
import numpy as np
import torch
# Constants
# Divisor pairs for the point-prompt pass: main() prompts SAM at
# (width // x, height // y) for each (x, y), i.e. at fraction (1/x, 1/y)
# of the image -> the centre (0.5, 0.5) and (0.4, 0.4).
XS_YS = [(divisor, divisor) for divisor in (2.0, 2.5)]
# Display width in pixels used for every st.image call.
WIDTH = 600
# Load models
@st.cache_resource
def load_models():
    """Load the segmentation model, its processor and the detection pipeline.

    Wrapped in ``st.cache_resource`` so the objects are created once and
    reused across Streamlit reruns instead of being reloaded on every
    interaction.

    Returns:
        ``(model, processor, od_pipe)``: the SlimSAM ``SamModel``, its
        ``SamProcessor``, and a DETR object-detection pipeline.
    """
    sam_checkpoint = "Zigeng/SlimSAM-uniform-77"
    sam_model = SamModel.from_pretrained(sam_checkpoint)
    sam_processor = SamProcessor.from_pretrained(sam_checkpoint)
    detector = pipeline("object-detection", "facebook/detr-resnet-50")
    return sam_model, sam_processor, detector
def process_image(image, model, processor, bounding_box=None, input_point=None):
    """Segment ``image`` with SAM from a single box or point prompt.

    Args:
        image: PIL image; converted to RGB before being handed to the processor.
        model: the SAM model returned by ``load_models``.
        processor: the matching SAM processor.
        bounding_box: box prompt, sent as ``input_boxes=[bounding_box]``.
            Used in preference to ``input_point`` whenever it is truthy.
        input_point: ``[x, y]`` point prompt, sent as
            ``input_points=[[input_point]]``.

    Returns:
        The post-processed masks for the single image in the batch, or
        ``None`` on any failure. Every exception, including the
        missing-prompt ``ValueError``, is reported via ``st.error``
        instead of being raised.
    """
    try:
        pixels = np.array(image.convert('RGB'))

        # Pick exactly one prompt kind; boxes win if both are supplied.
        if bounding_box:
            prompt = {"input_boxes": [bounding_box]}
        elif input_point:
            prompt = {"input_points": [[input_point]]}
        else:
            raise ValueError("Either bounding_box or input_point must be provided")

        inputs = processor(images=pixels, return_tensors="pt", **prompt)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        # Resize the low-resolution mask logits back to the input image size.
        masks_per_image = processor.image_processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"],
        )
        return masks_per_image[0]
    except Exception as err:
        st.error(f"Error processing image: {str(err)}")
        return None
def display_masked_images(raw_image, predicted_mask, caption_prefix):
    """Render each candidate mask as a cut-out of ``raw_image`` on white.

    Args:
        raw_image: the PIL image the masks were predicted for (any mode).
        predicted_mask: output of ``process_image``; ``predicted_mask[0]``
            holds the candidate masks. Each mask is expected to be an (H, W)
            array-like matching ``raw_image``'s size.
        caption_prefix: caption text; the 1-based candidate index is appended.
    """
    # Image.composite needs both images in the same mode, and the white
    # background is RGB. Converting here keeps RGBA, grayscale or palette
    # uploads from failing with "images do not match".
    base = raw_image.convert("RGB")
    # The background is never modified by composite, so build it once.
    background = Image.new("RGB", base.size, (255, 255, 255))
    # Show every candidate mask that was returned rather than assuming three.
    for i, mask in enumerate(predicted_mask[0]):
        int_mask = np.array(mask).astype(int) * 255
        # A 2-D uint8 array is inferred as mode "L"; passing ``mode=`` to
        # Image.fromarray is deprecated in recent Pillow releases.
        mask_image = Image.fromarray(int_mask.astype("uint8"))
        final_image = Image.composite(base, background, mask_image)
        st.image(final_image, caption=f"{caption_prefix} {i+1}", width=WIDTH)
def main():
    """Render the app: intro, image upload, detection and segmentation.

    Once an image is uploaded, the app:
    1. runs DETR object detection;
    2. segments each detected box with SlimSAM;
    3. segments again from the fixed point prompts derived from ``XS_YS``.
    """
    st.title("Image Segmentation with Object Detection")
    # Introduction and How-to
    st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further Segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.
Please note that processing takes some time. We appreciate your patience as the models do their work!
""")
    # Model credits
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")

    model, processor, od_pipe = load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # Normalise once, up front: apply the EXIF orientation tag and convert
        # to RGB. transformers' image pipelines do the same to PIL inputs
        # before detection, so DETR's boxes refer to the upright RGB image.
        # SAM and the mask compositing must see those same pixels, otherwise
        # boxes land in the wrong place on rotated phone photos.
        raw_image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded image: {err}")
        return

    st.subheader("Uploaded Image")
    st.image(raw_image, caption="Uploaded Image", width=WIDTH)

    with st.spinner('Processing image...'):
        # Object Detection
        pipeline_output = od_pipe(raw_image)
        if not pipeline_output:
            st.info("No objects were detected in this image.")

        # Process bounding boxes; each box is passed to SAM as
        # [[xmin, ymin], [xmax, ymax]].
        for detection in pipeline_output:
            box = detection['box']
            label = detection['label']
            bounding_box = [[box['xmin'], box['ymin']], [box['xmax'], box['ymax']]]
            st.subheader(f'bounding box : {label}')
            predicted_mask = process_image(raw_image, model, processor, bounding_box=bounding_box)
            if predicted_mask is not None:
                display_masked_images(raw_image, predicted_mask, "Masked Image")

        # Process input points at (width // x, height // y) for each divisor pair.
        width, height = raw_image.size
        for (x, y) in XS_YS:
            point_x, point_y = width // x, height // y
            st.subheader(f"Input points : ({1/x},{1/y})")
            predicted_mask = process_image(raw_image, model, processor, input_point=[point_x, point_y])
            if predicted_mask is not None:
                display_masked_images(raw_image, predicted_mask, "Masked Image")
# Entry point: build the page when this file is executed as the main script
# (presumably via `streamlit run <this file>`).
if __name__ == "__main__":
    main()