kendrickfff committed
Commit 7fb780e
1 Parent(s): ddb0e33

Update app.py

Files changed (1)
  1. app.py +41 -43
app.py CHANGED
@@ -4,16 +4,20 @@ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
 from PIL import Image
 import torch
 from torchvision import models, transforms
+import json
+import requests
 
-# Set up the environment for Google Generative AI
+# Set the environment variable for Google Application Credentials
 os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./firm-catalyst-437006-s4-407500537db5.json"
-llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
 
-# Load a pre-trained ResNet50 model for image analysis
+# Initialize the chat model with Hugging Face-specific environment variables
+llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
+
+# Load a pre-trained ResNet50 model for image classification
 model = models.resnet50(pretrained=True)
-model.eval()  # Set the model to evaluation mode
+model.eval()
 
-# Define the transformation for the image
+# Transformation pipeline for image preprocessing
 transform = transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
@@ -21,68 +25,62 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
 
-# Load the ImageNet labels
+# Load ImageNet labels
 LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
-labels = None
+labels = json.loads(requests.get(LABELS_URL).text)
 
-if not os.path.exists("imagenet_labels.json"):
-    import requests
-    response = requests.get(LABELS_URL)
-    with open("imagenet_labels.json", "wb") as f:
-        f.write(response.content)
-
-import json
-with open("imagenet_labels.json") as f:
-    labels = json.load(f)
+# Global chat history variable
+chat_history = []
 
-def chat_with_gemini(message, chat_history):
-    # Generate a response from the language model
+def chat_with_gemini(message):
+    global chat_history
+    # Get a response from the language model
     bot_response = llm.predict(message)
     chat_history.append((message, bot_response))
-
-    return chat_history, chat_history
+    return chat_history
 
-def analyze_image(image_path, chat_history):
-    # Load and preprocess the image
+def analyze_image(image_path):
+    global chat_history
+    # Open, preprocess, and classify the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0)
-
-    # Predict the image class
+
    with torch.no_grad():
        outputs = model(image_tensor)
        _, predicted_idx = outputs.max(1)
 
-    # Retrieve the label
    label = labels[predicted_idx.item()]
-
-    # Respond with the classification result
    bot_response = f"The image seems to be: {label}."
    chat_history.append(("Uploaded an image for analysis", bot_response))
-
-    return chat_history, chat_history
+    return chat_history
 
-# Create Gradio interface
-with gr.Blocks() as iface:
+# Build the Gradio interface
+with gr.Blocks() as demo:
    gr.Markdown("# Ken Chatbot")
    gr.Markdown("Ask me anything or upload an image for analysis!")
 
-    # Chatbot component without usernames
+    # Chatbot display without "User" or "Bot" labels
    chatbot = gr.Chatbot(elem_id="chatbot")
-
+
    # User input components
-    msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...")
+    msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
    send_btn = gr.Button("Send")
    img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
 
-    # State for chat history
-    state = gr.State([])
-
    # Define interactions
-    send_btn.click(chat_with_gemini, [msg, state], [chatbot, state]) # Handle text input
-    send_btn.click(lambda: "", None, msg) # Clear textbox
-    img_upload.change(analyze_image, [img_upload, state], [chatbot, state]) # Handle image uploads
+    def handle_text_message(message):
+        return chat_with_gemini(message)
+
+    def handle_image_upload(image_path):
+        return analyze_image(image_path)
+
+    # Set up Gradio components with Enter key for sending
+    msg.submit(handle_text_message, msg, chatbot)
+    send_btn.click(handle_text_message, msg, chatbot)
+    send_btn.click(lambda: "", None, msg) # Clear input field
+    img_upload.change(handle_image_upload, img_upload, chatbot)
 
-    # Custom CSS for styling chat bubbles without usernames
+    # Custom CSS for styling without usernames
    gr.HTML("""
    <style>
    #chatbot .message-container {
@@ -110,5 +108,5 @@
    </style>
    """)
 
-# Launch the Gradio interface
-iface.launch(debug=True)
+# Launch for Hugging Face Spaces
+demo.launch()
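
The rewired handlers return the module-level chat_history list straight into the Chatbot component instead of threading a gr.State value through two outputs. A minimal, self-contained sketch of that pattern, assuming Gradio's tuple-based Chatbot format and a placeholder echo function standing in for the llm.predict call:

import gradio as gr

chat_history = []  # module-level history, as in the commit

def respond(message):
    # Stand-in for the Gemini call; appends one (user, bot) pair
    chat_history.append((message, f"Echo: {message}"))
    return chat_history  # a list of (user, bot) tuples renders directly in gr.Chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(show_label=False)
    msg.submit(respond, msg, chatbot)   # Enter key sends the message
    msg.submit(lambda: "", None, msg)   # second listener clears the textbox

demo.launch()

Because the history lives in a module-level variable rather than in per-session gr.State, every visitor to the Space shares one conversation; that is a trade-off of this approach rather than a requirement of Gradio.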