# app.py for the "ask-me-anything" Hugging Face Space
import os
import gradio as gr
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import for Gemini
from PIL import Image
import torch
from torchvision import models, transforms
import json
import requests
# Load credentials from the Space's secrets (Hugging Face Secret Manager)
hf_token = os.environ.get("HF_TOKEN")  # Hugging Face API token, if set in the environment

# Initialize the Gemini chat model via langchain-google-genai
llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')  # Change 'gemini-1.5-pro' to the specific model you need
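# Note: ChatGoogleGenerativeAI authenticates via the GOOGLE_API_KEY environment
# variable, so that key must also be available as a Space secret (the HF_TOKEN
# above is not used by the Gemini client itself).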
# Load a pre-trained ResNet-50 model for image classification
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  # replaces the deprecated pretrained=True
model.eval()
# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
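# The resize/crop sizes and normalization statistics above are the standard
# ImageNet preprocessing expected by torchvision's pretrained ResNet-50.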
# Load ImageNet labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = json.loads(requests.get(LABELS_URL).text)
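# `labels` is a flat list of 1,000 human-readable class names, indexed by the
# class id that the classifier predicts.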
# Global chat history variable
chat_history = []
def chat_with_gemini(message):
    global chat_history
    # Get a response from the language model (invoke replaces the deprecated predict API)
    bot_response = llm.invoke(message).content  # This interacts with the Gemini model
    chat_history.append((message, bot_response))
    return chat_history

def analyze_image(image_path):
    global chat_history
    # Open, preprocess, and classify the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(image_tensor)
        _, predicted_idx = outputs.max(1)
    label = labels[predicted_idx.item()]
    bot_response = f"The image seems to be: {label}."
    chat_history.append(("Uploaded an image for analysis", bot_response))
    return chat_history
# Build the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Ken Chatbot")
gr.Markdown("Ask me anything or upload an image for analysis!")
# Chatbot display without "User" or "Bot" labels
chatbot = gr.Chatbot(elem_id="chatbot")
# User input components
msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
send_btn = gr.Button("Send")
img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
# Define interactions
def handle_text_message(message):
return chat_with_gemini(message)
def handle_image_upload(image_path):
return analyze_image(image_path)
# Set up Gradio components with Enter key for sending
msg.submit(handle_text_message, msg, chatbot)
send_btn.click(handle_text_message, msg, chatbot)
send_btn.click(lambda: "", None, msg) # Clear input field
img_upload.change(handle_image_upload, img_upload, chatbot)
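    # Note: with gr.Chatbot's classic tuple format, each handler returns the full
    # list of (user_message, bot_response) pairs, so the chat display re-renders
    # the entire history on every update.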
    # Custom CSS for styling without usernames
    gr.HTML("""
    <style>
        #chatbot .message-container {
            display: flex;
            flex-direction: column;
            margin-bottom: 10px;
            max-width: 70%;
        }
        #chatbot .message {
            border-radius: 15px;
            padding: 10px;
            margin: 5px 0;
            word-wrap: break-word;
        }
        #chatbot .message.user {
            background-color: #DCF8C6;
            margin-left: auto;
            text-align: right;
        }
        #chatbot .message.bot {
            background-color: #E1E1E1;
            margin-right: auto;
            text-align: left;
        }
    </style>
    """)
# Launch for Hugging Face Spaces
demo.launch()