Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import clip | |
from PIL import Image | |
from typing import List, Tuple | |
import numpy as np | |
# Load CLIP model and preprocessing | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model, preprocess = clip.load("ViT-B/32", device=device) | |
# Function to predict descriptions and probabilities | |
def predict(image: Image.Image, descriptions: List[str]) -> Tuple[str, float]: | |
""" | |
Predict the best matching description for the provided image based on the given descriptions. | |
Uses the CLIP model to compute similarities between the image and text descriptions. | |
Args: | |
image (Image.Image): The input image for which the descriptions are being evaluated. | |
descriptions (List[str]): A list of textual descriptions to compare against the image. | |
Returns: | |
Tuple[str, float]: A tuple containing the best-matching description and the corresponding probability. | |
""" | |
# Preprocess the image and move it to the appropriate device | |
image = preprocess(image).unsqueeze(0).to(device) | |
# Tokenize the descriptions and move them to the appropriate device | |
text = clip.tokenize(descriptions).to(device) | |
with torch.no_grad(): | |
# Encode image and text features using the CLIP model | |
image_features = model.encode_image(image) | |
text_features = model.encode_text(text) | |
# Compute the similarity scores (logits) between image and text | |
logits_per_image, logits_per_text = model(image, text) | |
# Convert logits to probabilities | |
probs = logits_per_image.softmax(dim=-1).cpu().numpy() | |
# Return the description with the highest probability and the corresponding probability | |
return descriptions[np.argmax(probs)], np.max(probs) | |
# Streamlit app | |
def main(): | |
st.title("Image understanding model") | |
# Instructions for the user | |
st.markdown("---") | |
st.markdown("### Upload an image to test how well the model understands it") | |
# Upload image through Streamlit with a unique key | |
uploaded_image = st.file_uploader("Upload an image...", type=["jpg", "png", "jpeg"], key="uploaded_image") | |
if uploaded_image is not None: | |
# Convert the uploaded image to PIL Image | |
pil_image = Image.open(uploaded_image) | |
# Limit the height of the displayed image to 400px | |
st.image(pil_image, caption="Uploaded Image.", use_column_width=True, width=200) | |
# Instructions for the user | |
st.markdown("### 2 Lies and 1 Truth") | |
st.markdown("Write 3 descriptions about the image, 1 must be true.") | |
# Get user input for descriptions | |
description1 = st.text_input("Description 1:", placeholder='A red apple') | |
description2 = st.text_input("Description 2:", placeholder='A car parked in a garage') | |
description3 = st.text_input("Description 3:", placeholder='An orange fruit on a tree') | |
descriptions = [description1, description2, description3] | |
# Button to trigger prediction | |
if st.button("Predict"): | |
if all(descriptions): | |
# Make predictions | |
best_description, best_prob = predict(pil_image, descriptions) | |
# Display the highest probability description and its probability | |
st.write(f"**Best Description:** {best_description}") | |
st.write(f"**Prediction Probability:** {best_prob:.2%}") | |
# Display progress bar for the highest probability | |
st.progress(float(best_prob)) | |
if __name__ == "__main__": | |
main() | |