capradeepgujaran committed
Commit 740f7c7 (parent: 7e6153d)

Update app.py

Files changed (1): app.py (+101, -57)
app.py CHANGED
@@ -16,27 +16,32 @@ def create_monitor_interface():
         def __init__(self):
             self.client = Groq()
             self.model_name = "llama-3.2-90b-vision-preview"
-            self.max_image_size = (640, 640)  # Increased size for better visibility
-            self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
+            self.max_image_size = (800, 800)  # Increased size for better quality
+            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
+            self.last_analysis_time = 0
+            self.analysis_interval = 2  # Analyze every 2 seconds
+            self.last_observations = []  # Store previous observations
 
         def resize_image(self, image):
             height, width = image.shape[:2]
-            aspect = width / height
 
-            if width > height:
-                new_width = min(self.max_image_size[0], width)
-                new_height = int(new_width / aspect)
-            else:
-                new_height = min(self.max_image_size[1], height)
-                new_width = int(new_height * aspect)
-
-            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+            # Only resize if image is too large
+            if height > self.max_image_size[1] or width > self.max_image_size[0]:
+                aspect = width / height
+                if width > height:
+                    new_width = self.max_image_size[0]
+                    new_height = int(new_width / aspect)
+                else:
+                    new_height = self.max_image_size[1]
+                    new_width = int(new_height * aspect)
+                return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+            return image
 
         def analyze_frame(self, frame: np.ndarray) -> str:
             if frame is None:
                 return "No frame received"
 
-            # Convert and resize image
+            # Convert image
             if len(frame.shape) == 2:
                 frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
             elif len(frame.shape) == 3 and frame.shape[2] == 4:
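The reworked `resize_image` now downscales only when a frame exceeds `max_image_size`, preserving aspect ratio instead of resizing unconditionally. A minimal standalone sketch of the same logic; the `shrink_to_fit` name and the asserts are illustrative, not part of the commit:

```python
import cv2
import numpy as np

def shrink_to_fit(image: np.ndarray, max_size=(800, 800)) -> np.ndarray:
    """Downscale only when a dimension exceeds max_size, keeping aspect ratio."""
    height, width = image.shape[:2]
    if height <= max_size[1] and width <= max_size[0]:
        return image  # already small enough; avoid a needless resample
    aspect = width / height
    if width > height:
        new_width, new_height = max_size[0], int(max_size[0] / aspect)
    else:
        new_height, new_width = max_size[1], int(max_size[1] * aspect)
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

# A 1920x1080 frame maps to 800x450; a 640x480 frame passes through untouched.
assert shrink_to_fit(np.zeros((1080, 1920, 3), np.uint8)).shape[:2] == (450, 800)
assert shrink_to_fit(np.zeros((480, 640, 3), np.uint8)).shape[:2] == (480, 640)
```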
@@ -45,11 +50,11 @@ def create_monitor_interface():
             frame = self.resize_image(frame)
             frame_pil = PILImage.fromarray(frame)
 
-            # Convert to base64 with minimal quality
+            # Convert to base64 with better quality
            buffered = io.BytesIO()
            frame_pil.save(buffered,
                           format="JPEG",
-                          quality=30,
+                          quality=85,  # Higher quality
                           optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"
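The frame is shipped to the Groq vision API as a base64 JPEG data URL; raising `quality` from 30 to 85 gives the model far more detail at the cost of a larger request body. The encoding step isolated as a hypothetical helper (`to_data_url` is not a name from the commit):

```python
import base64
import io

from PIL import Image

def to_data_url(image: Image.Image, quality: int = 85) -> str:
    """Encode an RGB PIL image as a base64-encoded JPEG data URL."""
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG", quality=quality, optimize=True)
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"

# Usage: url = to_data_url(Image.fromarray(frame))
# Base64 inflates the JPEG payload by roughly 4/3 over the raw bytes.
```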
@@ -63,9 +68,10 @@ def create_monitor_interface():
                         "content": [
                             {
                                 "type": "text",
-                                "text": """Analyze this workplace image and describe each safety concern in this format:
-                                - <location>Description</location>
-                                Use one line per issue, starting with a dash and location in tags."""
+                                "text": """Analyze this image for safety hazards. For each issue, describe:
+                                1. The location (top-left, center, bottom-right, etc.)
+                                2. The specific safety concern
+                                Format: - <location>position:description</location>"""
                             },
                             {
                                 "type": "image_url",
@@ -81,7 +87,7 @@ def create_monitor_interface():
                     }
                 ],
                 temperature=0.1,
-                max_tokens=150,
+                max_tokens=200,
                 top_p=1,
                 stream=False,
                 stop=None
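The new prompt asks the model for lines shaped like `- <location>position:description</location>`, which the updated `process_frame` later extracts with `find()` and a colon split. A regex-based sketch of the same parsing (the function and pattern names here are illustrative):

```python
import re

LOCATION_RE = re.compile(r"<location>(.*?)</location>")

def parse_observations(analysis: str) -> list[tuple[str, str]]:
    """Split '- <location>position:description</location>' lines into pairs."""
    results = []
    for line in analysis.splitlines():
        line = line.strip()
        if not line.startswith("-"):
            continue  # the prompt asks for one dash-prefixed line per issue
        match = LOCATION_RE.search(line)
        if not match:
            continue
        position, _, description = match.group(1).partition(":")
        results.append((position.strip(), description.strip() or position.strip()))
    return results

print(parse_observations("- <location>top-left:worker missing hard hat</location>"))
# [('top-left', 'worker missing hard hat')]
```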
@@ -91,59 +97,97 @@ def create_monitor_interface():
                 print(f"Detailed error: {str(e)}")
                 return f"Analysis Error: {str(e)}"
 
+        def get_region_coordinates(self, position: str, image_shape: tuple) -> tuple:
+            height, width = image_shape[:2]
+            regions = {
+                'top-left': (0, 0, width//3, height//3),
+                'top': (width//3, 0, 2*width//3, height//3),
+                'top-right': (2*width//3, 0, width, height//3),
+                'left': (0, height//3, width//3, 2*height//3),
+                'center': (width//3, height//3, 2*width//3, 2*height//3),
+                'right': (2*width//3, height//3, width, 2*height//3),
+                'bottom-left': (0, 2*height//3, width//3, height),
+                'bottom': (width//3, 2*height//3, 2*width//3, height),
+                'bottom-right': (2*width//3, 2*height//3, width, height)
+            }
+
+            # Find the best matching region
+            for region_name, coords in regions.items():
+                if region_name in position.lower():
+                    return coords
+
+            # Default to center if no match
+            return regions['center']
+
         def draw_observations(self, image, observations):
             height, width = image.shape[:2]
             font = cv2.FONT_HERSHEY_SIMPLEX
-            font_scale = 0.5
+            font_scale = 0.6
             thickness = 2
 
-            # Generate random positions for each observation
             for idx, obs in enumerate(observations):
                 color = self.colors[idx % len(self.colors)]
 
-                # Generate random box position
-                box_width = width // 3
-                box_height = height // 3
-                x = random.randint(0, width - box_width)
-                y = random.randint(0, height - box_height)
+                # Try to extract position from observation
+                parts = obs.split(':')
+                if len(parts) >= 2:
+                    position = parts[0]
+                    description = ':'.join(parts[1:])
+                else:
+                    position = 'center'
+                    description = obs
+
+                # Get coordinates based on position
+                x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
 
                 # Draw rectangle
-                cv2.rectangle(image, (x, y), (x + box_width, y + box_height), color, 2)
+                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
 
                 # Add label with background
-                label = obs[:40] + "..." if len(obs) > 40 else obs
+                label = description[:50] + "..." if len(description) > 50 else description
                 label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
-                cv2.rectangle(image, (x, y - 20), (x + label_size[0], y), color, -1)
-                cv2.putText(image, label, (x, y - 5), font, font_scale, (255, 255, 255), thickness)
+
+                # Ensure label stays within image bounds
+                label_x = max(0, min(x1, width - label_size[0]))
+                label_y = max(20, y1 - 5)
+
+                cv2.rectangle(image, (label_x, label_y - 20),
+                              (label_x + label_size[0], label_y), color, -1)
+                cv2.putText(image, label, (label_x, label_y - 5),
+                            font, font_scale, (255, 255, 255), thickness)
 
             return image
 
         def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
             if frame is None:
                 return None, "No image provided"
-
-            analysis = self.analyze_frame(frame)
-            display_frame = self.resize_image(frame.copy())
 
-            # Parse observations from the analysis
-            observations = []
-            for line in analysis.split('\n'):
-                line = line.strip()
-                if line.startswith('-'):
-                    # Extract text between <location> tags if present
-                    if '<location>' in line and '</location>' in line:
-                        start = line.find('<location>') + len('<location>')
-                        end = line.find('</location>')
-                        observation = line[end + len('</location>'):].strip()
-                    else:
-                        observation = line[1:].strip()  # Remove the dash
-                    if observation:
-                        observations.append(observation)
+            current_time = time.time()
+
+            # Only perform analysis if enough time has passed
+            if current_time - self.last_analysis_time >= self.analysis_interval:
+                analysis = self.analyze_frame(frame)
+                self.last_analysis_time = current_time
+
+                # Parse observations
+                observations = []
+                for line in analysis.split('\n'):
+                    line = line.strip()
+                    if line.startswith('-'):
+                        if '<location>' in line and '</location>' in line:
+                            start = line.find('<location>') + len('<location>')
+                            end = line.find('</location>')
+                            observation = line[start:end].strip()
+                            if observation:
+                                observations.append(observation)
+
+                self.last_observations = observations
 
-            # Draw observations on the image
-            annotated_frame = self.draw_observations(display_frame, observations)
+            # Draw observations on the frame
+            display_frame = frame.copy()
+            annotated_frame = self.draw_observations(display_frame, self.last_observations)
 
-            return annotated_frame, analysis
+            return annotated_frame, '\n'.join([f"- {obs}" for obs in self.last_observations])
 
     # Create the main interface
     monitor = SafetyMonitor()
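`process_frame` now throttles the vision call: it re-analyzes at most once per `analysis_interval` seconds and redraws `last_observations` for the frames in between. The new code also calls `time.time()` and no longer uses `random`, so `import time` must already be present at the top of app.py (the import block is outside this diff). The caching pattern in isolation, as a minimal sketch:

```python
import time

class Throttled:
    """Call an expensive function at most once per `interval` seconds,
    reusing the last result in between (the pattern process_frame uses)."""

    def __init__(self, fn, interval: float = 2.0):
        self.fn = fn
        self.interval = interval
        self.last_time = 0.0   # epoch seconds of the last real call
        self.last_result = None

    def __call__(self, *args):
        now = time.time()
        if now - self.last_time >= self.interval:
            self.last_result = self.fn(*args)
            self.last_time = now
        return self.last_result

# Usage: analyze = Throttled(monitor.analyze_frame, interval=2.0)
# Calling analyze(frame) for every streamed frame hits the API at most once
# every two seconds; intermediate frames reuse the cached analysis.
```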
@@ -152,12 +196,12 @@ def create_monitor_interface():
         gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
 
         with gr.Row():
-            input_image = gr.Image(label="Upload Image")
-            output_image = gr.Image(label="Annotated Results")
+            webcam = gr.Image(source="webcam", streaming=True, label="Live Feed")
+            output_image = gr.Image(label="Analysis")
 
-        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
+        analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
 
-        def analyze_image(image):
+        def analyze_stream(image):
             if image is None:
                 return None, "No image provided"
             try:
@@ -167,10 +211,10 @@ def create_monitor_interface():
                 print(f"Processing error: {str(e)}")
                 return None, f"Error processing image: {str(e)}"
 
-        input_image.change(
-            fn=analyze_image,
-            inputs=input_image,
-            outputs=[output_image, analysis_text]
+        webcam.stream(
+            fn=analyze_stream,
+            outputs=[output_image, analysis_text],
+            show_progress=False
         )
 
     return demo
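The interface moves from one-shot uploads (`input_image.change`) to live webcam streaming (`webcam.stream`). Below is a stripped-down sketch of that wiring under the Gradio 3.x API, which this commit appears to target; note that the commit omits `inputs=`, which some Gradio versions require explicitly, and that Gradio 4.x renamed `source=` to `sources=`:

```python
import gradio as gr

def analyze_stream(image):
    # Stand-in for monitor.process_frame(image): returns (annotated image, text).
    return image, "placeholder analysis"

with gr.Blocks() as demo:
    with gr.Row():
        webcam = gr.Image(source="webcam", streaming=True, label="Live Feed")
        output_image = gr.Image(label="Analysis")
    analysis_text = gr.Textbox(label="Safety Concerns", lines=5)

    webcam.stream(
        fn=analyze_stream,
        inputs=webcam,  # explicit here; the commit leaves inputs unset
        outputs=[output_image, analysis_text],
        show_progress=False,
    )

demo.launch()
```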
 