kendrickfff committed on
Commit
f80c42d
1 Parent(s): 40cc7e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -42
app.py CHANGED
@@ -1,19 +1,19 @@
1
- import os
2
  import gradio as gr
3
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
4
  from PIL import Image
5
  import torch
6
  from torchvision import models, transforms
 
 
7
 
8
- # Set up the environment for Google Generative AI
9
- os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./firm-catalyst-437006-s4-407500537db5.json"
10
- llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
11
 
12
- # Load a pre-trained ResNet50 model for image analysis
13
  model = models.resnet50(pretrained=True)
14
- model.eval() # Set the model to evaluation mode
15
 
16
- # Define the transformation for the image
17
  transform = transforms.Compose([
18
  transforms.Resize(256),
19
  transforms.CenterCrop(224),
@@ -21,68 +21,64 @@ transform = transforms.Compose([
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
- # Load the ImageNet labels
25
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
26
- labels = None
27
-
28
- if not os.path.exists("imagenet_labels.json"):
29
- import requests
30
- response = requests.get(LABELS_URL)
31
- with open("imagenet_labels.json", "wb") as f:
32
- f.write(response.content)
33
-
34
- import json
35
- with open("imagenet_labels.json") as f:
36
- labels = json.load(f)
37
 
38
  def chat_with_gemini(message, chat_history):
39
- # Generate a response from the language model
40
  bot_response = llm.predict(message)
41
  chat_history.append((message, bot_response))
42
-
43
- return chat_history, chat_history
44
 
45
  def analyze_image(image_path, chat_history):
46
- # Load and preprocess the image
47
  image = Image.open(image_path).convert("RGB")
48
  image_tensor = transform(image).unsqueeze(0)
49
-
50
- # Predict the image class
51
  with torch.no_grad():
52
  outputs = model(image_tensor)
53
  _, predicted_idx = outputs.max(1)
54
 
55
- # Retrieve the label
56
  label = labels[predicted_idx.item()]
57
-
58
- # Respond with the classification result
59
  bot_response = f"The image seems to be: {label}."
60
  chat_history.append(("Uploaded an image for analysis", bot_response))
61
-
62
- return chat_history, chat_history
63
 
64
- # Create Gradio interface
65
- with gr.Blocks() as iface:
66
  gr.Markdown("# Ken Chatbot")
67
  gr.Markdown("Ask me anything or upload an image for analysis!")
68
 
69
- # Chatbot component without usernames
70
  chatbot = gr.Chatbot(elem_id="chatbot")
71
-
72
  # User input components
73
- msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...")
74
  send_btn = gr.Button("Send")
75
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
76
 
77
- # State for chat history
78
- state = gr.State([])
79
 
80
  # Define interactions
81
- send_btn.click(chat_with_gemini, [msg, state], [chatbot, state]) # Handle text input
82
- send_btn.click(lambda: "", None, msg) # Clear textbox
83
- img_upload.change(analyze_image, [img_upload, state], [chatbot, state]) # Handle image uploads
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # Custom CSS for styling chat bubbles without usernames
86
  gr.HTML("""
87
  <style>
88
  #chatbot .message-container {
@@ -110,5 +106,4 @@ with gr.Blocks() as iface:
110
  </style>
111
  """)
112
 
113
- # Launch the Gradio interface
114
  iface.launch(debug=True)
 
 
1
  import gradio as gr
2
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
3
  from PIL import Image
4
  import torch
5
  from torchvision import models, transforms
6
+ import json
7
+ import requests
8
 
9
+ # Initialize the Gemini chat model (credentials are expected to come from environment variables)
10
+ llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
 
11
 
12
+ # Load a pre-trained ResNet50 model for image classification
13
  model = models.resnet50(pretrained=True)
14
+ model.eval()
15
 
16
+ # Transformation pipeline for image preprocessing
17
  transform = transforms.Compose([
18
  transforms.Resize(256),
19
  transforms.CenterCrop(224),
 
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
+ # Load ImageNet labels
25
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# Fetched once at import time. The timeout keeps a stalled connection from
# hanging startup indefinitely, and raise_for_status() turns an HTTP error
# page into a clear failure instead of a confusing JSON decode error.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
labels = json.loads(_labels_response.text)
 
 
 
 
 
 
 
 
 
 
27
 
28
def chat_with_gemini(message, chat_history):
    """Send a text message to the language model and record the exchange.

    Args:
        message: Text entered by the user.
        chat_history: List of ``(user_text, bot_text)`` tuples. It is
            mutated in place.

    Returns:
        The same ``chat_history`` list, with the new exchange appended when
        a non-blank message was supplied.
    """
    # A blank submission (empty textbox, stray Enter) has nothing to send;
    # skip the model call rather than posting an empty prompt.
    if not message or not message.strip():
        return chat_history

    # NOTE(review): newer LangChain releases deprecate `predict` in favour of
    # `invoke` -- confirm against the pinned version before upgrading.
    bot_response = llm.predict(message)
    chat_history.append((message, bot_response))
    return chat_history
 
33
 
34
def analyze_image(image_path, chat_history):
    """Classify an uploaded image and record the result in the chat history.

    Args:
        image_path: Filesystem path of the uploaded image, or a falsy value
            (``None``/empty) when no image is present.
        chat_history: List of ``(user_text, bot_text)`` tuples. It is
            mutated in place.

    Returns:
        The same ``chat_history`` list, with the classification message
        appended when an image path was supplied.
    """
    # NOTE(review): the image component's change event appears to fire with
    # None when the image is cleared; there is nothing to open in that case.
    if not image_path:
        return chat_history

    # Use a context manager so the underlying file handle is released;
    # convert() returns an independent, fully loaded RGB copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    # Add a leading batch dimension for the model.
    image_tensor = transform(image).unsqueeze(0)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(image_tensor)
    predicted_idx = outputs.argmax(dim=1)

    label = labels[predicted_idx.item()]
    bot_response = f"The image seems to be: {label}."
    chat_history.append(("Uploaded an image for analysis", bot_response))
    return chat_history
 
47
 
48
+ # Build the Gradio interface
49
+ with gr.Blocks() as demo:
50
  gr.Markdown("# Ken Chatbot")
51
  gr.Markdown("Ask me anything or upload an image for analysis!")
52
 
53
+ # Chatbot display without "User" or "Bot" labels
54
  chatbot = gr.Chatbot(elem_id="chatbot")
55
+
56
  # User input components
57
+ msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
58
  send_btn = gr.Button("Send")
59
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
60
 
61
+ # Local chat history state
62
+ chat_history = []
63
 
64
  # Define interactions
65
def handle_text_message(message):
    """Forward a text message to the chat model and return the updated history.

    Args:
        message: Text taken from the message textbox.

    Returns:
        The shared chat history list, suitable as the chatbot's new value.
    """
    # `chat_history` is assigned inside the `with gr.Blocks()` block, which
    # (as it appears here) sits at module level, so the name is a module
    # global. `nonlocal` requires an enclosing function scope and would be
    # rejected at compile time.
    # NOTE(review): a module-level list is shared by every visitor; confirm
    # whether per-session history (e.g. gr.State) is wanted instead.
    global chat_history
    chat_history = chat_with_gemini(message, chat_history)
    return chat_history
69
+
70
def handle_image_upload(image_path):
    """Classify an uploaded image and return the updated history.

    Args:
        image_path: File path supplied by the image upload component.

    Returns:
        The shared chat history list, suitable as the chatbot's new value.
    """
    # `chat_history` is assigned inside the `with gr.Blocks()` block, which
    # (as it appears here) sits at module level, so the name is a module
    # global. `nonlocal` requires an enclosing function scope and would be
    # rejected at compile time.
    global chat_history
    chat_history = analyze_image(image_path, chat_history)
    return chat_history
74
+
75
+ # Set up Gradio components with Enter key for sending
76
+ msg.submit(handle_text_message, msg, chatbot)
77
+ send_btn.click(handle_text_message, msg, chatbot)
78
+ send_btn.click(lambda: "", None, msg) # Clear input field
79
+ img_upload.change(handle_image_upload, img_upload, chatbot)
80
 
81
+ # Custom CSS for styling without usernames
82
  gr.HTML("""
83
  <style>
84
  #chatbot .message-container {
 
106
  </style>
107
  """)
108
 
 
109
# Launch the app. The Blocks context is bound to `demo`; the former name
# `iface` no longer exists in this version of the file.
demo.launch(debug=True)