|
import streamlit as st |
|
from st_pages import Page, show_pages, add_page_title, Section |
|
from lib.utils.model import get_model, get_similarities, get_detr, segment_images |
|
from lib.utils.timer import timer |
|
|
|
# Title for the current page, provided by the st_pages helper.
add_page_title()

# Pages and section headings registered with st_pages, in display order.
# NOTE(review): 'Retrival' looks like a typo for 'Retrieval'; it is left
# untouched here because the page name is user-visible.
navigation = [
    Page('app.py', 'IRRA Text-To-Image-Retrival'),
    Section('Implementation Details'),
    Page('pages/losses.py', 'Loss functions'),
]
show_pages(navigation)
|
|
|
# Short description of the model plus the expected input image format.
intro_text = '''
A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501).
The uploaded images should be `384x128` with only one person in the shot.
'''
st.markdown(intro_text)
|
|
|
# --- Inputs: the text query and the gallery of candidate images -----------
st.header('Inputs')
caption = st.text_input('Description Input')

images = st.file_uploader('Upload images', accept_multiple_files=True)
# With accept_multiple_files=True the uploader returns a list, which is empty
# (not None) when nothing has been uploaded. Test truthiness so the preview is
# only rendered once there is at least one file to show.
if images:
    st.image(images)
|
|
|
# --- Options: number of ranked results, optional segmentation, run button --
st.header('Options')
st.subheader('Ranks', help='How many predictions the model is allowed to make')

# The slider's own label is collapsed; the subheader above acts as its caption.
ranks = st.slider(
    'slider_ranks',
    min_value=1,
    max_value=10,
    value=5,
    label_visibility='collapsed',
)
do_segment = st.checkbox('Segment images with DETR', value=False)

# Matching is only possible with at least one image and a non-empty caption.
can_match = len(images) > 0 and caption != ''
button = st.button('Match most similar', disabled=not can_match)
|
|
|
|
|
# --- Results: run the model and list the best-matching images -------------
if button:
    # Optionally pre-process the uploads with DETR before matching.
    if do_segment:
        detr, processor = get_detr()
        images = segment_images(detr, processor, images)

    st.header('Results')
    with st.spinner('Loading model'):
        model = get_model()

    # Total parameter count, reported in millions.
    param_millions = sum(p.numel() for p in model.parameters()) / 1e6
    st.text(f'IRRA model loaded with {param_millions:.0f}M parameters')

    with st.spinner('Computing and ranking similarities'):
        with timer() as t:
            # squeeze(0) drops the leading dimension of the result
            # (presumably the single-caption axis -- TODO confirm).
            similarities = get_similarities(caption, images, model).squeeze(0)
        # Read the elapsed time once the timed block has exited.
        # NOTE(review): assumes the callable yielded by timer() stays usable
        # after the block exits -- confirm against lib.utils.timer.
        elapsed = t()

        # Indices of the best matches, highest similarity first, capped at
        # the number of ranks requested.
        indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]

    # Table header.
    rank_col, image_col, score_col = st.columns(3)
    rank_col.subheader('Rank')
    image_col.subheader('Image')
    score_col.subheader(
        'Cosine Similarity',
        help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is',
    )

    # One row of three columns per ranked match; ranks are displayed 1-based.
    for rank, idx in enumerate(indices, start=1):
        rank_col, image_col, score_col = st.columns(3)
        rank_col.text(f'{rank}')
        image_col.image(images[idx])
        score_col.text(f'{similarities[idx].cpu():.2f}')

    st.success(f'Done in {elapsed:.2f}s')
|
|
|
# --- Sidebar: app title and reference links --------------------------------
# Markdown links shown under "Useful Links", in display order.
useful_links = (
    '[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)',
    '[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)',
    '[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)',
)

with st.sidebar:
    st.title('IRRA Text-To-Image Retrival')

    st.subheader('Useful Links')
    for link in useful_links:
        st.markdown(link)
|
|