Spaces:
Runtime error
Runtime error
File size: 2,517 Bytes
8d9306e 6f0178d 8d9306e c951094 374fa3e 8d9306e 0bb133b f884ea7 6f0178d c951094 6f0178d f884ea7 d28411b 8d9306e 144ec50 8d9306e 6f0178d 144ec50 6f0178d f705683 8d9306e 6f0178d 144ec50 6f0178d 144ec50 6f0178d 9a6a97f 8d9306e 6f0178d 144ec50 6f0178d 8d9306e 686f21e 8d9306e d28411b c951094 6f0178d 8d9306e 686f21e c951094 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import io
import os

import requests
import streamlit as st
# Designing the interface: page title, author credit, and a sidebar blurb
# describing the ViT + GPT2 captioning model shown in this demo.
_SIDEBAR_ABOUT = """
An image captioning model by combining ViT model with GPT2 model.
The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
[Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
"""

st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")
st.sidebar.markdown(_SIDEBAR_ABOUT)
# Importing `model` is slow (per the spinner text it loads and compiles the
# ViT-GPT2 model), so it happens under a spinner. The names are imported
# explicitly rather than via `from model import *` so the script's
# dependencies on `model` are visible in one place.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import (
        Image,  # NOTE(review): presumably PIL.Image re-exported by `model`; confirm
        get_random_image_id,
        predict,
        sample_dir,
        sample_image_ids,
    )
# Image selection: use the sample chosen in the sidebar, or a random COCO 2017
# validation image when "None" is selected or the random button is pressed.
st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    sample_image_ids,
)

# The button overrides any selected sample with a random image.
if st.sidebar.button("Random COCO 2017 (val) images"):
    sample_image_id = "None"

if sample_image_id == "None":
    # Draw a random id only when it is actually going to be used.
    image_id = get_random_image_id()
else:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(sample_image_id, int):
        raise TypeError(
            f"Expected an integer COCO image id, got {sample_image_id!r}"
        )
    image_id = sample_image_id
def _fit_within(width, height, max_width=512, max_height=384):
    """Return ``(width, height)`` scaled down to fit inside max_width x max_height.

    The aspect ratio is preserved. Sizes already inside the box are returned
    unchanged, and each side is clamped to at least 1 pixel.
    """
    if height > max_height:
        width = max(1, int(width / height * max_height))
        height = max_height
    if width > max_width:
        # Derive the new height from the *current* width before overwriting it;
        # overwriting width first would leave height unchanged and skew the image.
        height = max(1, int(height / width * max_width))
        width = max_width
    return width, height


def _load_coco_image(image_id):
    """Open COCO 2017 val image ``image_id``, preferring the local sample copy.

    Falls back to downloading from images.cocodataset.org.

    Raises:
        requests.HTTPError: if the download returns an error status.
        requests.Timeout: if the server does not respond in time.
    """
    padded_id = str(image_id).zfill(12)
    sample_path = os.path.join(sample_dir, f"COCO_val2017_{padded_id}.jpg")
    if os.path.isfile(sample_path):
        return Image.open(sample_path)
    url = f"http://images.cocodataset.org/val2017/{padded_id}.jpg"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Buffer the body in memory so the image does not depend on an open socket.
    return Image.open(io.BytesIO(response.content))


image = _load_coco_image(image_id)

# Shrink only for display; `image` itself is kept at its original size.
display_size = _fit_within(*image.size)
resized = image if display_size == image.size else image.resize(display_size)

padded_image_id = str(image_id).zfill(12)
st.markdown(f"[{padded_image_id}.jpg](http://images.cocodataset.org/val2017/{padded_image_id}.jpg)")
st.image(resized, caption='\n\nSelected Image')

# Close only the display copy. When no resize was needed, `resized` *is*
# `image`, which is still required below for prediction.
if resized is not image:
    resized.close()
# For newline
st.sidebar.write('\n')

# Caption the loaded image (not the display-sized copy). Release it once
# prediction finishes, even if prediction raises; it is not used afterwards.
try:
    with st.spinner('Generating image caption ...'):
        caption = predict(image)
finally:
    image.close()

st.header('Predicted caption:\n\n')
st.subheader(caption)
st.sidebar.header("ViT-GPT2 predicts:")
st.sidebar.write(f"**English**: {caption}")
|