Spaces:
Runtime error
Runtime error
from io import BytesIO | |
import streamlit as st | |
import pandas as pd | |
import json | |
import os | |
import numpy as np | |
from streamlit.elements import markdown | |
from PIL import Image | |
from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import ( | |
FlaxCLIPVisionMBartForConditionalGeneration, | |
) | |
from transformers import MBart50TokenizerFast | |
from utils import ( | |
get_transformed_image, | |
) | |
import matplotlib.pyplot as plt | |
from mtranslate import translate | |
from session import _get_state | |
# Session-state object from the local `session` helper module. The attributes
# set on it below (image_file, caption, lang_id, image) are presumably kept
# across Streamlit reruns -- confirm in session.py.
state = _get_state()


def _cache_across_reruns(func):
    """Wrap *func* in Streamlit's resource cache so its result survives reruns.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` is used when this Streamlit version provides it;
    otherwise this falls back to the legacy ``st.cache`` with
    ``allow_output_mutation=True``, so the returned object is not re-hashed
    on each hit. The caller shows its own spinner, hence
    ``show_spinner=False``.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_across_reruns
def load_model(ckpt):
    """Load the Flax CLIP-Vision/mBART captioning model from checkpoint *ckpt*.

    Args:
        ckpt: Checkpoint path (or identifier) accepted by ``from_pretrained``.

    Returns:
        The ``FlaxCLIPVisionMBartForConditionalGeneration`` instance, cached
        per ``ckpt`` so repeated reruns do not reload it from disk.
    """
    return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
# Pretrained mBART-50 tokenizer; used to encode the target-language token and
# to decode generated ids back to text.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

# Short language code used by the UI and reference data -> mBART-50 language
# code; generate_sequence looks the latter up in tokenizer.lang_code_to_id.
language_mapping = dict(en="en_XX", de="de_DE", fr="fr_XX", es="es_XX")

# Human-readable names for the selectable languages. Insertion order is also
# the order of the options shown in the language selectbox.
code_to_name = dict(en="English", fr="French", de="German", es="Spanish")
def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, max_length=64):
    """Generate caption text for a preprocessed image in the requested language.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        pixel_values: Model input produced by ``get_transformed_image``; it is
            handed to ``model.generate`` through the ``input_ids`` argument.
            NOTE(review): assumed to carry a leading batch axis -- confirm
            against ``utils.get_transformed_image``.
        lang_code: Key of ``language_mapping`` ("en", "de", "fr" or "es").
        num_beams: Beam width for beam search.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        max_length: Maximum generated length in tokens (default 64).

    Returns:
        List of decoded strings, one per generated sequence, with special
        tokens removed.

    Raises:
        KeyError: If ``lang_code`` is not a key of ``language_mapping``.
    """
    # Translate the UI code to the mBART-50 code and force its token as the
    # first decoder token so the caption is produced in that language.
    mbart_lang = language_mapping[lang_code]
    output = model.generate(
        input_ids=pixel_values,
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_lang],
        max_length=max_length,
        num_beams=num_beams,
        # NOTE(review): in transformers' Flax generate, temperature/top_p are
        # applied only when sampling (do_sample=True); with num_beams >= 2 they
        # are most likely ignored -- confirm before relying on these sliders.
        temperature=temperature,
        top_p=top_p,
    )
    # output[0] is the first field of the generate output: the generated ids.
    return tokenizer.batch_decode(output[0], skip_special_tokens=True)
def read_markdown(path, parent="./sections/"):
    """Return the full text of markdown file *path* located under *parent*.

    The file is decoded as UTF-8 explicitly, so non-ASCII content reads the
    same regardless of the host's locale-dependent default encoding.

    Args:
        path: File name relative to *parent*.
        parent: Directory containing the markdown sections.

    Returns:
        The file contents as a ``str``.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
# Model checkpoint paths; the app loads checkpoints[0].
# TODO: Maybe add more checkpoints?
checkpoints = ["./ckpt/ckpt-17499"]

# Seeded examples for the demo, stored tab-separated; the app reads the
# image_file, caption and lang_id columns.
dummy_data = pd.read_csv("reference.tsv", delimiter="\t")
PAGE_TITLE = "Multilingual Image Captioning"
AUTHOR_LINKS = (
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)",
    "[Gunjan Chhablani](https://huggingface.co/gchhablani)",
)

# Streamlit expects the page configuration before other elements are drawn.
st.set_page_config(page_title=PAGE_TITLE, layout="wide", initial_sidebar_state="collapsed")
st.title(PAGE_TITLE)
st.write(", ".join(AUTHOR_LINKS))
st.sidebar.title("Generation Parameters")

# Slider stops 0.0, 0.1, ..., 1.0 as exactly-rounded plain floats.
# np.arange with a float step yields values such as 0.30000000000000004
# (masked by the 2-decimal format_func) that would otherwise be passed
# verbatim to generate().
_unit_steps = [round(0.1 * i, 1) for i in range(11)]

num_beams = st.sidebar.number_input(
    label="Number of Beams",
    min_value=2,
    max_value=10,
    value=4,
    step=1,
    help="Number of beams to be used in beam search.",
)
temperature = st.sidebar.select_slider(
    label="Temperature",
    options=_unit_steps,
    value=1.0,
    help="The value used to modulate the next token probabilities.",
    format_func=lambda x: f"{x:.2f}",
)
top_p = st.sidebar.select_slider(
    label="Top-P",
    options=_unit_steps,
    value=1.0,
    help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.",
    format_func=lambda x: f"{x:.2f}",
)
# Newer Streamlit releases provide `st.expander` and no longer ship the
# `beta_expander` alias; use whichever this Streamlit version has.
_expander = getattr(st, "expander", None) or st.beta_expander

with _expander("Usage"):
    st.markdown(read_markdown("usage.md"))

with _expander("Article"):
    st.write(read_markdown("abstract.md"))
    st.write(read_markdown("caveats.md"))
    # st.write("# Methodology")
    # st.image(
    #     "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
    # )
    st.markdown(read_markdown("pretraining.md"))
    st.write(read_markdown("challenges.md"))
    st.write(read_markdown("social_impact.md"))
    st.write(read_markdown("references.md"))
    # st.write(read_markdown("checkpoints.md"))
    st.write(read_markdown("acknowledgements.md"))
# Row of dummy_data shown when a session starts.
first_index = 20


def _show_example(row):
    """Make *row* of ``dummy_data`` the current example in session state.

    Copies ``image_file`` and ``lang_id``, stores the caption with leading
    and trailing "-" / space characters stripped, and reads the image from
    the ``images`` directory into ``state.image``.
    """
    state.image_file = row["image_file"]
    state.caption = row["caption"].strip("- ")
    state.lang_id = row["lang_id"]
    state.image = plt.imread(os.path.join("images", state.image_file))


# Init Session State
if state.image_file is None:
    _show_example(dummy_data.loc[first_index])

# col1, col2 = st.beta_columns([6, 4])
if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
    _show_example(dummy_data.sample(1).iloc[0])

# col2.write("OR")
# uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
# if uploaded_file is not None:
#     state.image_file = os.path.join("images", uploaded_file.name)
#     state.image = np.array(Image.open(uploaded_file))
# Model input derived from the current image (consumed by generate_sequence).
transformed_image = get_transformed_image(state.image)

# Newer Streamlit releases provide `st.columns` and no longer ship the
# `beta_columns` alias; use whichever this Streamlit version has.
_columns = getattr(st, "columns", None) or st.beta_columns
new_col1, new_col2 = _columns([5, 5])

# Display Image
new_col1.image(state.image, use_column_width="always")

# Display Reference Caption, plus an English rendering when it is not English.
new_col2.write("**Reference Caption**: " + state.caption)
reference_in_english = (
    state.caption if state.lang_id == "en" else translate(state.caption, "en")
)
new_col2.markdown(f"**English Translation**: {reference_in_english}")

# Select Language -- preselects the language of the reference caption.
options = list(code_to_name.keys())
lang_id = new_col2.selectbox(
    "Language",
    index=options.index(state.lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
    help="The language in which caption is to be generated.",
)
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

# [""] means "nothing generated yet"; replaced when the button is pressed.
sequence = [""]
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
    with st.spinner("Generating Sequence..."):
        sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p)

if sequence != [""]:
    generated = sequence[0]
    st.write("**Generated Caption**: " + generated)
    # Compute the English text first: `"label" + a if c else b` would parse as
    # `("label" + a) if c else b` and drop the label for non-English captions.
    # Target English explicitly, as done for the reference caption.
    english = generated if lang_id == "en" else translate(generated, "en")
    st.write("**English Translation**: " + english)