File size: 2,956 Bytes
ba2ab36
9ff0cd2
 
 
ba2ab36
3617c74
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2ab36
 
 
 
 
 
9ff0cd2
 
ba2ab36
 
c5c3fa2
ba2ab36
9ff0cd2
 
 
 
 
ba2ab36
 
 
9ff0cd2
 
 
 
ba2ab36
 
 
 
9ff0cd2
 
 
9ebc77b
ba2ab36
9ebc77b
 
 
ba2ab36
c5c3fa2
9ff0cd2
3617c74
 
 
 
 
 
9ff0cd2
 
 
ba2ab36
c5c3fa2
ba2ab36
3617c74
ba2ab36
 
c5c3fa2
3617c74
 
9ebc77b
 
3617c74
 
 
 
 
9ff0cd2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import streamlit as st
from st_pages import Page, show_pages, add_page_title, Section
from lib.utils.model import get_model, get_similarities, get_detr, segment_images
from lib.utils.timer import timer

# st_pages helper: adds the title configured for the current page.
add_page_title()

# Sidebar navigation entries, in display order: the main app page, an
# 'Implementation Details' section entry, and the loss-function page.
show_pages(
    [
        Page('app.py', 'IRRA Text-To-Image-Retrieval'),
        Section('Implementation Details'),
        Page('pages/losses.py', 'Loss functions'),
    ]
)

# Introductory blurb: links the paper the model is implemented from and states
# the expected format of uploaded images.
intro_markdown = '''
            A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501).
            The uploaded images should be `384x128` with only one person in the shot. 
            '''
st.markdown(intro_markdown)

# Query inputs: a free-text description plus the candidate images to rank.
st.header('Inputs')
caption = st.text_input('Description Input')

# With accept_multiple_files=True the uploader returns a list that is empty
# (not None) until something is uploaded, so check truthiness instead of
# comparing against None; this also skips rendering an empty preview.
images = st.file_uploader('Upload images', accept_multiple_files=True)
if images:
    # Preview the uploads.
    st.image(images)  # type: ignore

# Matching options shown below the inputs.
st.header('Options')
st.subheader('Ranks', help='How many predictions the model is allowed to make')

# The slider's own label is collapsed; the subheader above acts as its title.
ranks = st.slider(
    'slider_ranks',
    value=5,
    min_value=1,
    max_value=10,
    label_visibility='collapsed',
)
do_segment = st.checkbox('Segment images with DETR', value=False)

# Matching is only possible with at least one image and a non-empty
# description; until then the button is disabled.
ready_to_match = len(images) > 0 and caption != ''
button = st.button('Match most similar', disabled=not ready_to_match)


if button:
    # When the checkbox is ticked, pass the uploads through segment_images and
    # use its result in place of the raw uploads for everything below.
    if do_segment:
        detr, processor = get_detr()
        images = segment_images(detr, processor, images)

    st.header('Results')
    with st.spinner('Loading model'):
        model = get_model()

    # Report the model size in millions of parameters.
    n_params = sum(p.numel() for p in model.parameters())
    st.text(f'IRRA model loaded with {n_params / 1e6:.0f}M parameters')

    # Only the similarity computation is timed. The timer handle is read
    # after both context managers exit to get the elapsed seconds.
    with st.spinner('Computing and ranking similarities'):
        with timer() as t:
            similarities = get_similarities(caption, images, model).squeeze(0)
    elapsed = t()

    # Image indices ordered by descending similarity, truncated to the number
    # of ranks chosen on the slider (slicing tolerates fewer images than ranks).
    indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]

    # Table header row.
    c1, c2, c3 = st.columns(3)
    with c1:
        st.subheader('Rank')
    with c2:
        st.subheader('Image')
    with c3:
        st.subheader('Cosine Similarity',
                     help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is')

    # One row per ranked image: 1-based rank, the image, and its score.
    for rank, idx in enumerate(indices, start=1):
        c1, c2, c3 = st.columns(3)
        with c1:
            st.text(f'{rank}')
        with c2:
            st.image(images[idx])
        with c3:
            st.text(f'{similarities[idx].cpu():.2f}')

    st.success(f'Done in {elapsed:.2f}s')

# Sidebar: app title plus reference links to the paper and two implementations.
with st.sidebar:
    st.title('IRRA Text-To-Image Retrieval')

    st.subheader('Useful Links')
    st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
    st.markdown(
        '[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)')
    st.markdown(
        '[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)')