Spaces:
Running
Running
import os | |
import gradio as gr | |
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import for Gemini | |
from PIL import Image | |
import torch | |
from torchvision import models, transforms | |
import json | |
import requests | |
# Read the Hugging Face API token from the Space's secrets, which are exposed
# as environment variables. NOTE(review): hf_token is not used by any of the
# code visible here — confirm whether it is still needed.
hf_token = os.environ.get("HF_TOKEN")

# Initialize the Gemini chat model through LangChain's Google GenAI
# integration. The model name can be overridden with the GEMINI_MODEL
# environment variable and defaults to 'gemini-1.5-pro' as before.
# NOTE(review): this client does not receive HF_TOKEN; it presumably
# authenticates with a Google API key (e.g. GOOGLE_API_KEY) — confirm the
# Space defines that secret.
llm = ChatGoogleGenerativeAI(model=os.environ.get("GEMINI_MODEL", "gemini-1.5-pro"))
# Load a ResNet50 image classifier with ImageNet-pretrained weights.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum API. IMAGENET1K_V1 is exactly the checkpoint that
# `pretrained=True` selected, so predictions are unchanged; using
# `ResNet50_Weights.DEFAULT` instead would silently switch to the V2 weights.
try:
    _resnet_weights = models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:
    # torchvision < 0.13 has no weight enums; fall back to the legacy flag.
    model = models.resnet50(pretrained=True)
else:
    model = models.resnet50(weights=_resnet_weights)
# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()
# Per-channel mean/std used to normalize the tensor (the usual ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before classification:
# shorter side resized to 256, centre 224x224 crop, converted to a tensor,
# then normalized with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load ImageNet class names: a JSON list that is indexed by the classifier's
# predicted output index.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# The timeout keeps app start-up from hanging forever on a stalled connection.
# raise_for_status() reports HTTP failures (e.g. 404) directly, rather than as
# a confusing JSON decode error on the error page's body.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
labels = _labels_response.json()
# Conversation log written by both event handlers: a list of
# (user_message, bot_reply) tuples handed back to the Gradio Chatbot.
# NOTE(review): being a module-level global, this history is shared by every
# visitor of the Space and grows without bound; per-session gr.State would
# keep users' conversations separate.
chat_history = list()
def chat_with_gemini(message):
    """Send *message* to the Gemini model and record the exchange.

    Args:
        message: Text entered by the user.

    Returns:
        The module-level ``chat_history`` list of ``(user, bot)`` tuples,
        which Gradio renders in the Chatbot component. Blank or
        whitespace-only messages are ignored, and the history is returned
        unchanged.
    """
    global chat_history
    # Submitting an empty textbox would otherwise send a blank prompt to the
    # model and add an empty exchange to the history.
    if not message or not message.strip():
        return chat_history
    # `predict` is deprecated on LangChain chat models; `invoke` returns an
    # AIMessage whose `.content` holds the reply text.
    bot_response = llm.invoke(message).content
    chat_history.append((message, bot_response))
    return chat_history
def analyze_image(image_path):
    """Classify the image at *image_path* with ResNet50 and log the result.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            the Gradio image component has been cleared.

    Returns:
        The module-level ``chat_history`` list. When an image was classified,
        a new ``("Uploaded an image for analysis", ...)`` entry is appended.
    """
    global chat_history
    # The image component's `change` event also fires when the image is
    # removed, passing None; there is nothing to classify in that case.
    if image_path is None:
        return chat_history
    # Open inside `with` so the file handle is released once the pixels are
    # converted. convert("RGB") guarantees three channels for the model.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        outputs = model(image_tensor)
    _, predicted_idx = outputs.max(1)
    label = labels[predicted_idx.item()]
    bot_response = f"The image seems to be: {label}."
    chat_history.append(("Uploaded an image for analysis", bot_response))
    return chat_history
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ken Chatbot")
    gr.Markdown("Ask me anything or upload an image for analysis!")
    # Chatbot display without "User" or "Bot" labels
    chatbot = gr.Chatbot(elem_id="chatbot")
    # User input components
    msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
    send_btn = gr.Button("Send")
    img_upload = gr.Image(type="filepath", label="Upload an image for analysis")

    # Define interactions
    def handle_text_message(message):
        """Answer *message*; return (updated history, "" to clear the textbox)."""
        return chat_with_gemini(message), ""

    def handle_image_upload(image_path):
        """Classify the uploaded image and return the updated history."""
        return analyze_image(image_path)

    # Enter and the Send button share one handler that also empties the
    # textbox. Previously only Send cleared it, so a message submitted with
    # Enter stayed in the box.
    msg.submit(handle_text_message, msg, [chatbot, msg])
    send_btn.click(handle_text_message, msg, [chatbot, msg])
    img_upload.change(handle_image_upload, img_upload, chatbot)

    # Custom CSS for styling without usernames
    gr.HTML("""
    <style>
    #chatbot .message-container {
        display: flex;
        flex-direction: column;
        margin-bottom: 10px;
        max-width: 70%;
    }
    #chatbot .message {
        border-radius: 15px;
        padding: 10px;
        margin: 5px 0;
        word-wrap: break-word;
    }
    #chatbot .message.user {
        background-color: #DCF8C6;
        margin-left: auto;
        text-align: right;
    }
    #chatbot .message.bot {
        background-color: #E1E1E1;
        margin-right: auto;
        text-align: left;
    }
    </style>
    """)

# Launch for Hugging Face Spaces
demo.launch()