# Lee-Shang's picture
# Update app.py
# 2911a7c
import random

import sahi.utils.cv
import sahi.utils.file
import sahi.utils.mmdet
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from streamlit_image_comparison import image_comparison

from utils import sahi_mmdet_inference
# --- Detector weights/config (downloaded into the working directory) -------
MMDET_YOLOX_TINY_MODEL_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/resolve/main/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth"
MMDET_YOLOX_TINY_MODEL_PATH = "yolox.pt"
MMDET_YOLOX_TINY_CONFIG_URL = "https://huggingface.co/fcakyon/mmdet-yolox-tiny/raw/main/yolox_tiny_8x8_300e_coco.py"
MMDET_YOLOX_TINY_CONFIG_PATH = "config.py"

# --- Image assets, keyed by the local file name they are saved/shown as ----
# The first four are selectable sample inputs; the "-yolox"/"-sahi" entries
# are pre-rendered outputs shown before any prediction is run.
IMAGE_TO_URL = {
    "apple_tree.jpg": "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "highway.jpg": "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
    "highway2.jpg": "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
    "highway3.jpg": "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
    "highway2-yolox.jpg": "https://user-images.githubusercontent.com/34196005/143309873-c0c1f31c-c42e-4a36-834e-da0a2336bb19.jpg",
    "highway2-sahi.jpg": "https://user-images.githubusercontent.com/34196005/143309867-42841f5a-9181-4d22-b570-65f90f2da231.jpg",
}

# --- Sliced-inference settings forwarded to sahi_mmdet_inference -----------
slice_size = 512  # used for both slice height and width
overlap_ratio = 0.2  # used for both height and width overlap

# Post-processing of merged slice predictions.
postprocess_type = 'NMS'
postprocess_match_metric = 'IOU'
postprocess_match_threshold = 0.5
postprocess_class_agnostic = True
@st.cache(allow_output_mutation=True, show_spinner=False)
def download_comparison_images():
    """Download the pre-rendered comparison images into the working directory.

    Each file is saved under its ``IMAGE_TO_URL`` key name. These are the
    relative paths later opened to seed the "output_1"/"output_2" session
    state. URLs are read from ``IMAGE_TO_URL`` rather than repeated here, so
    the two cannot drift apart.
    """
    for file_name in ("highway2-yolox.jpg", "highway2-sahi.jpg"):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[file_name], file_name)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model():
    """Fetch the MMDetection checkpoint and config, then build a CPU detector.

    Both files are stored at the MMDET_YOLOX_TINY_*_PATH locations. The
    ``st.cache`` decorator keeps the returned SAHI detection model across
    reruns.

    Returns:
        The model built by ``AutoDetectionModel.from_pretrained`` with a
        0.5 confidence threshold on the CPU.
    """
    downloads = (
        (MMDET_YOLOX_TINY_MODEL_URL, MMDET_YOLOX_TINY_MODEL_PATH),
        (MMDET_YOLOX_TINY_CONFIG_URL, MMDET_YOLOX_TINY_CONFIG_PATH),
    )
    for source_url, local_path in downloads:
        sahi.utils.file.download_from_url(source_url, local_path)

    return AutoDetectionModel.from_pretrained(
        model_type='mmdet',
        model_path=MMDET_YOLOX_TINY_MODEL_PATH,
        config_path=MMDET_YOLOX_TINY_CONFIG_PATH,
        confidence_threshold=0.5,
        device="cpu",
    )
class SpinnerTexts:
    """Pool of spinner messages that avoids repeating recent choices.

    ``get`` prefers a message whose index is not among the last
    ``HISTORY_SIZE`` indices it returned. If every message was used recently
    (always the case for a pool with one entry), it falls back to any message
    rather than looping forever.
    """

    # How many recently returned indices to remember.
    HISTORY_SIZE = 6

    def __init__(self):
        # Indices returned by recent ``get`` calls, oldest first.
        self.ind_history_list = []
        # Candidate messages.
        self.text_list = [
            "Loading...",
        ]

    def _store(self, ind):
        """Append *ind* to the history, keeping only the newest entries."""
        self.ind_history_list.append(ind)
        # Trim from the front so at most HISTORY_SIZE entries remain, even if
        # the list was somehow longer than that beforehand.
        del self.ind_history_list[:-self.HISTORY_SIZE]

    def get(self):
        """Return a spinner message, avoiding recently used ones if possible.

        Index 0 is chosen whenever it is not in the recent history. Otherwise
        an index is drawn uniformly from those not in the history. If every
        index is in the history, it is drawn uniformly from all of them.

        Raises:
            IndexError: if ``text_list`` is empty.
        """
        ind = 0
        if ind in self.ind_history_list:
            all_inds = range(len(self.text_list))
            fresh = [i for i in all_inds if i not in self.ind_history_list]
            # Drawing only from `fresh` gives the same uniform distribution as
            # retrying until an unused index appears, but it always
            # terminates. With no fresh index left, reuse the whole pool.
            ind = random.choice(fresh or all_inds)
        self._store(ind)
        return self.text_list[ind]
# Browser-tab title/icon and wide layout for the whole app.
st.set_page_config(
    page_title="A Demonstration of SARAI's Utility",
    page_icon="🐦",
    layout="wide",
    initial_sidebar_state="auto",
)

# Make sure the pre-rendered comparison images exist locally.
download_comparison_images()

# Seed per-session state on the first run only. Later reruns keep whatever the
# prediction step stored. Defaults are built lazily so the image files are
# opened only when a key is actually missing.
_SESSION_DEFAULTS = (
    ("last_spinner_texts", SpinnerTexts),
    ("output_1", lambda: Image.open("highway2-yolox.jpg")),
    ("output_2", lambda: Image.open("highway2-sahi.jpg")),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Centered page title (raw HTML, hence unsafe_allow_html).
st.markdown(
    """
<h2 style='text-align: center'>
A Demonstration of SARAI's Utility
</h2>
""",
    unsafe_allow_html=True,
)
st.write("##")  # vertical spacer

# Collapsible usage instructions. The stray trailing quote after
# "image processing" has been removed from the displayed text.
with st.expander("Instructions for Use"):
    st.markdown(
        """
<p>
1. Upload or select the input image
<br />
2. Press "Perform Prediction" to start image processing
</p>
""",
        unsafe_allow_html=True,
    )
st.write("##")  # vertical spacer
# Sample images selectable from the slider, in slider order. This single list
# feeds both the slider options and its 1-based labels.
SAMPLE_IMAGE_NAMES = ["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"]

# Layout: input column | spacer | comparison column.
col1, col2, col3 = st.columns([4, 1, 6])
with col1:
    st.markdown("##### Set input image:")
    # set input image by upload
    image_file = st.file_uploader(
        "Upload an image:", type=["jpg", "jpeg", "png"]
    )

    # set input image from examples
    def slider_func(option):
        """Return the 1-based position of *option* as its slider label."""
        return str(SAMPLE_IMAGE_NAMES.index(option) + 1)

    slider = st.select_slider(
        "Or select from our sample images:",
        options=SAMPLE_IMAGE_NAMES,
        format_func=slider_func,
        value="highway2.jpg",
    )

    # visualize input image; an upload takes precedence over the slider.
    if image_file is not None:
        # Uploaded PNGs may be RGBA or palette images. Normalize to 3-channel
        # RGB before the image is handed to the detector.
        image = Image.open(image_file).convert("RGB")
    else:
        image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[slider])
    st.image(image, width=325)

with col3:
    st.markdown("##### YOLOX Standard vs SARAI Prediction:")
    # Before/after slider over the two images kept in session state.
    static_component = image_comparison(
        img1=st.session_state["output_1"],
        img2=st.session_state["output_2"],
        label1="YOLOX",
        label2="SARAI",
        width=700,
        starting_position=50,
        show_labels=True,
        make_responsive=True,
        in_memory=True,
    )
# Layout for the action row: the button sits in col2, the caption in col4.
col1, col2, col3, col4, col5 = st.columns([1, 2, 4, 2, 2])
with col2:
    # submit button
    submit = st.button("Perform Prediction")

if submit:
    # Fetch the detector. get_model is wrapped in st.cache, so after the first
    # call this should be a cache hit.
    with st.spinner(
        text="Downloading model weight.. "
        + st.session_state["last_spinner_texts"].get()
    ):
        detection_model = get_model()

    # Image size forwarded to sahi_mmdet_inference. Presumably the detector's
    # input resolution; confirm against utils.sahi_mmdet_inference.
    image_size = 416

    with st.spinner(
        text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()
    ):
        output_1, output_2 = sahi_mmdet_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )

    # Persist the results so they survive reruns.
    st.session_state["output_1"] = output_1
    st.session_state["output_2"] = output_2

    # The comparison widget is drawn earlier in the script, so during this run
    # it already showed the previous outputs. Rerun so the new ones appear
    # immediately. Newer Streamlit exposes st.rerun; older releases only have
    # st.experimental_rerun.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

with col4:
    st.markdown("##### Slide to Compare")