ask-me-anything / app.py
kendrickfff's picture
Change ResNet50 to DETR for object detection
a963f7f verified
raw
history blame
4.45 kB
import os
import gradio as gr
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import for Gemini
from PIL import Image
import json
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
import requests
# --- Google credentials -----------------------------------------------------
# GOOGLE_APPLICATION_CREDENTIALS is expected to hold the service-account key
# as a JSON *string* (e.g. a deployment secret). The Google SDKs want a file
# path instead, so the JSON is written to disk and the variable is re-pointed
# at that file.
_CREDENTIALS_FILE = "service_account.json"

credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
if not credentials_string:
    raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set in the environment!")

if os.path.isfile(credentials_string):
    # The variable already points at a key file. This happens when this module
    # is executed a second time in the same process (the variable is rewritten
    # below), or when a real key-file path is supplied. Parsing the path as
    # JSON would fail, so load the file instead.
    with open(credentials_string, encoding="utf-8") as f:
        credentials = json.load(f)
else:
    try:
        # Parse the stringified JSON back to a Python dictionary
        credentials = json.loads(credentials_string)
    except json.JSONDecodeError as err:
        raise ValueError(
            "GOOGLE_APPLICATION_CREDENTIALS must contain service-account JSON "
            "or the path of an existing key file."
        ) from err

    # Save the credentials to a JSON file (required by Google SDKs). The file
    # is created with owner-only permissions because it contains a private key;
    # the mode only applies when the file is newly created.
    _fd = os.open(_CREDENTIALS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(_fd, "w", encoding="utf-8") as f:
        json.dump(credentials, f)

    # Point the Google SDKs at the file that was just written
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = _CREDENTIALS_FILE

# Initialize Gemini model
llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')

# Initialize DETR model and processor
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
model.eval()  # inference only

# Global chat history: list of (user_message, bot_response) tuples shared by
# the text and image handlers
chat_history = []
def chat_with_gemini(message):
    """Send a text prompt to Gemini and record the exchange in the chat history.

    Args:
        message: The user's text prompt.

    Returns:
        The module-level ``chat_history`` list of ``(user_message,
        bot_response)`` tuples, including the new exchange. The list is
        returned unchanged when ``message`` is empty or whitespace-only.
    """
    # A blank prompt cannot produce a useful answer, so skip the API call.
    if not message or not message.strip():
        return chat_history

    # ``predict`` is deprecated in LangChain; ``invoke`` is the supported entry
    # point and returns a message object whose ``content`` carries the reply.
    reply = llm.invoke(message).content
    # ``content`` may be a list of content parts rather than plain text; the
    # chat display expects a string.
    bot_response = reply if isinstance(reply, str) else str(reply)

    # The list is mutated in place, so no ``global`` declaration is required.
    chat_history.append((message, bot_response))
    return chat_history
def analyze_image(image_path, threshold=0.9):
    """Detect objects in an image with DETR and record the result in the chat history.

    Args:
        image_path: Filesystem path of the image to analyse. A falsy value
            (``None`` or ``""``) means there is no image to analyse.
        threshold: Minimum class probability a detection must exceed to be
            reported. Defaults to 0.9.

    Returns:
        The module-level ``chat_history`` list, extended with one
        ``(user_message, bot_response)`` entry summarising the detected
        labels. The list is returned unchanged when ``image_path`` is falsy.
    """
    # The upload widget's change event may also fire when the image is
    # removed, in which case there is no path to open.
    if not image_path:
        return chat_history

    # Load and preprocess image; the context manager closes the file handle
    # once the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Class probabilities for the single image in the batch, computed once.
    # The last logit column is excluded from the score, as in the original
    # filtering logic.
    probabilities = outputs.logits.softmax(-1)[0, :, :-1]
    scores, class_ids = probabilities.max(-1)

    # Keep only predictions whose best class probability exceeds the threshold
    keep = scores > threshold
    predicted_classes = class_ids[keep]
    predicted_boxes = outputs.pred_boxes[0, keep].tolist()

    # Map class IDs to labels. The id -> label table lives on the model's
    # config; the image processor has no ``config`` attribute.
    labels = [model.config.id2label[idx.item()] for idx in predicted_classes]

    # Combine predictions
    predictions = [
        {"label": label, "box": box}
        for label, box in zip(labels, predicted_boxes)
    ]

    # Create response
    if predictions:
        detected_objects = ', '.join(p["label"] for p in predictions)
        bot_response = f"The image contains: {detected_objects}."
    else:
        bot_response = "No objects with high confidence were detected."

    chat_history.append(("Uploaded an image for analysis", bot_response))
    return chat_history
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ken Chatbot")
    gr.Markdown("Ask me anything or upload an image for analysis!")

    # Chatbot display without "User" or "Bot" labels
    chatbot = gr.Chatbot(elem_id="chatbot")

    # User input components
    msg = gr.Textbox(
        label="Type your message here...",
        placeholder="Enter your message...",
        show_label=False,
    )
    send_btn = gr.Button("Send")
    img_upload = gr.Image(type="filepath", label="Upload an image for analysis")

    # Define interactions
    def handle_text_message(message):
        """Forward the text to Gemini.

        Returns the updated chat history for the chatbot display and an empty
        string that clears the input textbox.
        """
        return chat_with_gemini(message), ""

    def handle_image_upload(image_path):
        """Run object detection on the uploaded image and return the chat history."""
        return analyze_image(image_path)

    # Enter key and Send button share one handler. Clearing the textbox is an
    # output of that same handler, so both paths behave identically and no
    # second, independently scheduled listener is needed.
    msg.submit(handle_text_message, msg, [chatbot, msg])
    send_btn.click(handle_text_message, msg, [chatbot, msg])
    img_upload.change(handle_image_upload, img_upload, chatbot)

    # Custom CSS for styling without usernames
    gr.HTML("""
    <style>
    #chatbot .message-container {
        display: flex;
        flex-direction: column;
        margin-bottom: 10px;
        max-width: 70%;
    }
    #chatbot .message {
        border-radius: 15px;
        padding: 10px;
        margin: 5px 0;
        word-wrap: break-word;
    }
    #chatbot .message.user {
        background-color: #DCF8C6;
        margin-left: auto;
        text-align: right;
    }
    #chatbot .message.bot {
        background-color: #E1E1E1;
        margin-right: auto;
        text-align: left;
    }
    </style>
    """)

# Launch for Hugging Face Spaces
demo.launch()