kendrickfff committed on
Commit
b05e484
·
verified ·
1 Parent(s): 61993c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -114
app.py CHANGED
@@ -1,35 +1,12 @@
1
  import os
2
  import gradio as gr
 
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
- from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import Gemini
5
  from PIL import Image
6
- import torch
7
- import json
8
  import requests
 
9
 
10
- # Load credentials (stringified JSON) from environment variable for Gemini
11
- credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
12
- if not credentials_string:
13
- raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set in the environment!")
14
-
15
- # Parse the stringified JSON back to a Python dictionary
16
- credentials = json.loads(credentials_string)
17
-
18
- # Save the credentials to a temporary JSON file (required by Google SDKs)
19
- with open("service_account.json", "w") as f:
20
- json.dump(credentials, f)
21
-
22
- # Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the temporary file
23
- os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json"
24
-
25
- # Initialize Gemini model (chatbot)
26
- llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
27
-
28
- # Initialize DETR model and processor for object detection
29
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
30
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
31
-
32
- # Load COCO class label
33
  COCO_CLASSES = [
34
  'airplane', 'apple', 'backpack', 'banana', 'baseball hat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle',
35
  'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair',
@@ -41,105 +18,64 @@ COCO_CLASSES = [
41
  'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass'
42
  ]
43
 
 
 
 
44
 
45
- # Global chat history variable
46
- chat_history = []
47
-
48
- # Function for chatting with Gemini
49
- def chat_with_gemini(message):
50
- global chat_history
51
- bot_response = llm.predict(message) # This will interact with the Gemini model
52
- chat_history.append((message, bot_response))
53
- return chat_history
54
-
55
- # Function for analyzing the uploaded image
56
  def analyze_image(image_path):
57
- global chat_history
58
  try:
59
- # Open and preprocess the image
60
- image = Image.open(image_path).convert("RGB")
61
- inputs = processor(images=image, return_tensors="pt")
62
-
63
- # Perform inference
64
- with torch.no_grad():
65
- outputs = model(**inputs)
66
-
67
- # Set a target size for post-processing
68
- target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
69
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
70
 
71
- # Collect detected objects
72
- detected_objects = []
73
- for idx, label in enumerate(results["labels"]):
74
- # Get the object label based on label index
75
- object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
76
- detected_objects.append(object_name)
77
 
78
- if detected_objects:
79
- bot_response = f"Objects detected: {', '.join(detected_objects)}."
80
- else:
81
- bot_response = "No objects detected."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- chat_history.append(("Uploaded an image for analysis", bot_response))
84
- return chat_history
85
  except Exception as e:
86
- error_msg = f"Error processing the image: {str(e)}"
87
- chat_history.append(("Error during image analysis", error_msg))
88
- return chat_history
89
 
90
- # Build the Gradio interface
91
- with gr.Blocks() as demo:
92
- gr.Markdown("# Ken Chatbot")
93
- gr.Markdown("Ask me anything or upload an image for analysis!")
94
 
95
- # Chatbot display without "User" or "Bot" labels
96
- chatbot = gr.Chatbot(elem_id="chatbot")
 
 
97
 
98
  # User input components
99
- msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
100
- send_btn = gr.Button("Send")
101
- img_upload = gr.Image(type="filepath", label="Upload an image for analysis (Only detect 80 types of images recognized from COCO dataset. Check the list on https://blog.roboflow.com/microsoft-coco-classes/")
102
-
103
- # Define interactions
104
- def handle_text_message(message):
105
- return chat_with_gemini(message)
106
-
107
- def handle_image_upload(image_path):
108
- return analyze_image(image_path)
109
-
110
- # Set up Gradio components with Enter key for sending
111
- msg.submit(handle_text_message, msg, chatbot)
112
- send_btn.click(handle_text_message, msg, chatbot)
113
- send_btn.click(lambda: "", None, msg) # Clear input field
114
- img_upload.change(handle_image_upload, img_upload, chatbot)
115
 
116
- # Custom CSS for styling without usernames
117
- gr.HTML("""
118
- <style>
119
- #chatbot .message-container {
120
- display: flex;
121
- flex-direction: column;
122
- margin-bottom: 10px;
123
- max-width: 70%;
124
- }
125
- #chatbot .message {
126
- border-radius: 15px;
127
- padding: 10px;
128
- margin: 5px 0;
129
- word-wrap: break-word;
130
- }
131
- #chatbot .message.user {
132
- background-color: #DCF8C6;
133
- margin-left: auto;
134
- text-align: right;
135
- }
136
- #chatbot .message.bot {
137
- background-color: #E1E1E1;
138
- margin-right: auto;
139
- text-align: left;
140
- }
141
- </style>
142
- """)
143
 
144
- # Launch the Gradio interface
145
  demo.launch()
 
1
  import os
2
  import gradio as gr
3
+ import torch
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
 
5
  from PIL import Image
 
 
6
  import requests
7
+ import json
8
 
9
+ # Custom Object Labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  COCO_CLASSES = [
11
  'airplane', 'apple', 'backpack', 'banana', 'baseball hat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle',
12
  'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair',
 
18
  'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass'
19
  ]
20
 
21
# Fetch the pretrained DETR (ResNet-50 backbone) checkpoint once at import
# time: the processor prepares images for the network, the model runs detection.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
24
 
25
def analyze_image(image_path, threshold=0.5):
    """Detect objects in an image with DETR and describe them as text.

    Args:
        image_path: Filesystem path of the image to analyze. May be ``None``
            or empty, e.g. when the upload widget has been cleared.
        threshold: Minimum class probability (between 0 and 1) a detection
            must exceed to be reported.

    Returns:
        A human-readable string: "Detected Objects:" followed by one
        "label (confidence: x.xx)" line per detection, "No objects detected."
        when nothing clears the threshold, "No image provided." when there is
        no input, or "Error processing the image: ..." if anything fails.
    """
    # Nothing to analyze; avoid surfacing a confusing error to the user.
    if not image_path:
        return "No image provided."

    try:
        # Normalize to 3-channel RGB so grayscale, palette and RGBA uploads
        # are all accepted; the context manager closes the file handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        # Raw logits are unbounded scores, so convert them to probabilities
        # before thresholding. The last class is DETR's "no object" slot and
        # must not compete for the predicted label.
        probabilities = outputs.logits.softmax(-1)[0, :, :-1]
        scores, class_ids = probabilities.max(-1)

        # Use the checkpoint's own id -> name table so labels always match
        # the class ids the model was trained with.
        id2label = model.config.id2label
        lines = []
        for score, class_id in zip(scores.tolist(), class_ids.tolist()):
            if score > threshold:
                label = id2label.get(class_id, f"class {class_id}")
                lines.append(f"{label} (confidence: {score:.2f})")

        if not lines:
            return "No objects detected."

        detected_objects = "\n".join(lines)
        return f"Detected Objects:\n{detected_objects}"
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error processing the image: {str(e)}"
 
 
66
 
 
 
 
 
67
 
68
# Assemble the single-page UI: an image upload wired to a read-only results box.
with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Custom Labels")
    gr.Markdown("Upload an image for analysis!")

    # Input widget (hands a file path to the callback) and output widget.
    img_upload = gr.Image(label="Upload an image for analysis", type="filepath")
    output_text = gr.Textbox(interactive=False, label="Detection Results")

    # Re-run detection every time the uploaded image changes.
    img_upload.change(fn=analyze_image, inputs=img_upload, outputs=output_text)

# Start the Gradio server.
demo.launch()