Spaces:
Sleeping
Sleeping
import urllib
import urllib.request
from io import BytesIO
from time import time

import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
    """
    Scale an image to a target height while maintaining the aspect ratio.

    Parameters
    ----------
    image : Image.Image
        The image to scale.
    target_height : int, optional (default=500)
        The target height of the image.

    Returns
    -------
    Image.Image
        The scaled image. Its width is the proportionally scaled width,
        truncated to an integer and never smaller than one pixel.
    """
    width, height = image.size
    aspect_ratio = width / height
    # Very tall, narrow images would otherwise truncate to a width of zero,
    # which is not a usable resize target; keep at least one pixel column.
    target_width = max(1, int(aspect_ratio * target_height))
    return image.resize((target_width, target_height))
def upload_image() -> None:
    """
    Open the file held under the ``file_uploader`` session key.

    The opened image is stored in ``st.session_state.image``. When the
    uploader holds no file, the session state is left unchanged.
    """
    uploaded_file = st.session_state.file_uploader
    if uploaded_file is None:
        return
    st.session_state.image = Image.open(uploaded_file)
def read_image_from_url() -> None:
    """
    Read an image from a URL.

    Fetches the address stored under the ``image_url`` session key and
    stores the opened image in ``st.session_state.image``. Nothing is
    fetched when the value is missing, empty, or only whitespace.
    """
    # A cleared text input holds an empty string rather than None, so test
    # for emptiness instead of identity with None.
    image_url = (st.session_state.image_url or "").strip()
    if not image_url:
        return
    # NOTE(review): the URL is user-supplied and urlopen is not limited to
    # http(s) (e.g. file: URLs); consider restricting the accepted schemes.
    with urllib.request.urlopen(image_url) as response:
        st.session_state.image = Image.open(BytesIO(response.read()))
def inference() -> None:
    """
    Perform inference on an image and generate a caption.

    Reads the image, processor, model, device and generation settings
    (``max_length``, ``num_beams``) from ``st.session_state``. On success
    the decoded caption is written to ``st.session_state.caption`` and the
    elapsed time in seconds, rounded to two decimals, to
    ``st.session_state.inference_time``. The model is moved back to the
    CPU afterwards whether or not generation succeeded.
    """
    device = st.session_state.device.lower()
    if device == "cuda" and not torch.cuda.is_available():
        # The device selector offers CUDA unconditionally; fall back rather
        # than fail on hosts without a GPU.
        st.warning("CUDA is not available; running on CPU instead.")
        device = "cpu"

    model = st.session_state.model
    processor = st.session_state.processor

    start_time = time()
    try:
        inputs = processor(
            images=st.session_state.image,
            return_tensors="pt",
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        model.to(device)
        token_ids = model.generate(
            **inputs,
            max_length=st.session_state.max_length,
            num_beams=st.session_state.num_beams,
        )
        caption = processor.decode(token_ids[0], skip_special_tokens=True)
        st.session_state.inference_time = round(time() - start_time, 2)
        st.session_state.caption = caption
    finally:
        # Run the cleanup even when processing or generation raised, so the
        # model does not stay on the accelerator between reruns.
        model.to("cpu")
        torch.cuda.empty_cache()
def _init_session_state() -> None:
    """Load the model and processor once per session and seed the defaults."""
    if "model" not in st.session_state:
        st.session_state.model = AutoModelForVision2Seq.from_pretrained(
            "tanthinhdt/blip-base_with-pretrained_flickr30k",
            cache_dir="models/huggingface",
        )
        st.session_state.model.eval()
    if "processor" not in st.session_state:
        st.session_state.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            cache_dir="models/huggingface",
        )
    if "image" not in st.session_state:
        st.session_state.image = None
    if "caption" not in st.session_state:
        st.session_state.caption = None
    if "inference_time" not in st.session_state:
        st.session_state.inference_time = 0.0


def _render_sidebar() -> None:
    """Draw the sidebar: image sources first, then the generation settings."""
    st.sidebar.header("Workspace")
    st.sidebar.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False,
        on_change=upload_image,
        key="file_uploader",
        help="Upload an image to generate a caption.",
    )
    st.sidebar.text_input(
        "Image URL",
        on_change=read_image_from_url,
        key="image_url",
        help="Enter the URL of an image to generate a caption.",
    )
    st.sidebar.divider()
    st.sidebar.header("Settings")
    st.sidebar.selectbox(
        label="Device",
        options=["CPU", "CUDA"],
        # Preselect CUDA only when a GPU is actually usable.
        index=1 if torch.cuda.is_available() else 0,
        key="device",
        help="The device to use for inference.",
    )
    st.sidebar.number_input(
        label="Max length",
        min_value=32,
        max_value=128,
        value=64,
        step=1,
        key="max_length",
        help="The maximum length of the generated caption.",
    )
    st.sidebar.number_input(
        label="Number of beams",
        min_value=1,
        max_value=10,
        value=4,
        step=1,
        key="num_beams",
        help="The number of beams to use during decoding.",
    )


def main() -> None:
    """
    Main function for the Streamlit app.

    Configures the page, prepares the session state, draws the sidebar and
    the main layout, and, when an image is loaded, shows it together with
    its resolution, the generated caption and the inference time.
    """
    # Set page configuration first so it precedes every other Streamlit
    # call, including the slow model download on a fresh session.
    st.set_page_config(
        page_title="Image Captioning App",
        page_icon="📸",
        initial_sidebar_state="expanded",
    )

    _init_session_state()

    # Set sidebar layout
    _render_sidebar()

    # Set main layout
    st.markdown(
        """
        <h1 style='text-align: center;'>
        Image Captioning
        </h1>
        """,
        unsafe_allow_html=True,
    )
    st.divider()
    image_container = st.container(height=450)
    st.divider()
    col_1, col_2, col_3 = st.columns([1, 1, 2])
    resolution_display = col_1.empty()
    runtime_display = col_2.empty()
    caption_display = col_3.empty()

    # Display the image and generate a caption
    image = st.session_state.image
    if image is not None:
        image_container.image(scale_image(image, target_height=400))
        resolution_display.metric(
            label="Image Resolution",
            value=f"{image.width}x{image.height}",
        )
        with st.spinner("Generating caption..."):
            inference()
        # NOTE(review): the caption and timing widgets are assumed to belong
        # to this branch, i.e. shown only once an image is loaded - confirm.
        caption_display.text_area(
            label="Caption",
            value=st.session_state.caption,
        )
        runtime_display.metric(
            label="Inference Time",
            value=f"{st.session_state.inference_time}s",
        )
# Start the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()