Spaces:
Runtime error
Runtime error
import functools | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from tensorflow.keras.applications import efficientnet | |
#import efficientnet | |
from tensorflow.keras.layers import TextVectorization | |
import matplotlib.pyplot as plt | |
import cv2 | |
from models import EMBED_DIM, FF_DIM, SEQ_LENGTH, ImageCaptioningModel, TransformerDecoderBlock, TransformerEncoderBlock, get_cnn_model, image_augmentation, vectorization, valid_data, decode_and_resize | |
def _load_caption_model():
    """Build the image-captioning model and load its trained weights from ``model_weights``.

    The architecture built here (CNN backbone, 1-head Transformer encoder,
    2-head Transformer decoder) has to match the one the checkpoint was saved
    from for ``load_weights`` to restore it.
    """
    cnn_model = get_cnn_model()
    encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
    decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
    model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
    )
    model.load_weights('model_weights')
    return model


def _generate_caption(model, image_bgr):
    """Greedily decode a caption for one BGR image (as returned by ``cv2.imdecode``).

    Args:
        model: the ``ImageCaptioningModel`` returned by ``_load_caption_model``.
        image_bgr: the decoded image array in OpenCV's BGR channel order.

    Returns:
        The caption text with the ``<start>``/``<end>`` markers removed
        (may be an empty string).

    Raises:
        OSError: if the image cannot be written to the temporary file.
    """
    import tempfile
    from pathlib import Path

    import numpy as np

    vocab = vectorization.get_vocabulary()
    max_decoded_sentence_length = SEQ_LENGTH - 1

    # decode_and_resize takes a file path, so round-trip the image through disk.
    # A per-call temporary directory keeps concurrent sessions from overwriting
    # each other's uploads.
    with tempfile.TemporaryDirectory() as tmp_dir:
        img_path = str(Path(tmp_dir) / "uploaded_image.jpg")
        if not cv2.imwrite(img_path, image_bgr):
            raise OSError(f"could not write temporary image to {img_path}")
        sample_img = decode_and_resize(img_path)

    # Image -> CNN features -> Transformer encoder.
    img_features = model.cnn_model(tf.expand_dims(sample_img, 0))
    encoded_img = model.encoder(img_features, training=False)

    # Greedy decoding: re-tokenize the caption so far and take the arg-max
    # token predicted at position i.
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        # assumes token id 0 is padding — TODO confirm against the vectorizer config
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token = vocab[int(np.argmax(predictions[0, i, :]))]
        # Vocabulary entries are whole tokens; strip() makes the end check
        # independent of any surrounding whitespace.
        if sampled_token.strip() == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    return decoded_caption.replace(" <end>", "").strip()


def display_UI():
    """Render the caption-generation page.

    Shows a header with the logo and an image uploader. Once an image is
    uploaded, it is displayed, the captioning model is built, and the
    generated caption is shown (or an error if decoding/captioning fails to
    produce text).
    """
    import numpy as np
    import streamlit as st

    # Let the main block container use the full page width with small padding.
    st.markdown(""" <style> .appview-container .main .block-container {
    max-width: 100%;
    padding-top: 1rem;
    padding-right: 1rem;
    padding-left: 1rem;
    padding-bottom: 1rem;
    }</style> """, unsafe_allow_html=True)
    # CSS class used by the header text below.
    st.markdown(""" <style> .font {
    font-size:25px ; font-family: 'Cooper Black'; color: #FF9633;}
    </style> """, unsafe_allow_html=True)

    col1, col2 = st.columns([0.8, 0.2])
    with col1:  # header text, styled by the .font class above
        st.markdown('<p class="font">Generate Caption of image</p>', unsafe_allow_html=True)
    with col2:  # brand logo
        st.image('./logo.png', width=50)

    uploaded_file = st.file_uploader("Choose an Image File", type=[".jpg", ".jpeg", ".png", ".PNG"],
                                     accept_multiple_files=False)
    if uploaded_file is None:
        return

    with st.spinner('Wait for it...'):
        file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
        # IMREAD_COLOR (== 1) yields a 3-channel BGR image, or None if the
        # bytes cannot be decoded as an image.
        opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if opencv_image is None:
        st.error("Could not read the uploaded file as an image", icon="🚨")
        return
    # OpenCV decodes to BGR, so tell Streamlit the channel order.
    st.image(opencv_image, channels="BGR", width=400)

    with st.spinner('Loading the model..'):
        # NOTE(review): the model is rebuilt on every script run; consider
        # st.cache_resource (Streamlit >= 1.18) to load it once per process.
        new_model = _load_caption_model()
    st.success('Model Loaded!', icon="✅")

    with st.spinner("Getting prediction..."):
        caption = _generate_caption(new_model, opencv_image)
    if caption:
        st.success(caption, icon="✅")
    else:
        st.error("Error occurred in caption generation", icon="🚨")


display_UI()