capradeepgujaran committed on
Commit
bda20be
1 Parent(s): bd1163f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -79
app.py CHANGED
@@ -7,7 +7,6 @@ from PIL import Image as PILImage
7
  import io
8
  import os
9
  import base64
10
- import random
11
 
12
  def create_monitor_interface():
13
  api_key = os.getenv("GROQ_API_KEY")
@@ -16,26 +15,26 @@ def create_monitor_interface():
16
  def __init__(self):
17
  self.client = Groq()
18
  self.model_name = "llama-3.2-90b-vision-preview"
19
- self.max_image_size = (800, 800) # Increased size for better visibility
20
- self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
21
 
22
  def resize_image(self, image):
23
  height, width = image.shape[:2]
24
- aspect = width / height
25
-
26
- if width > height:
27
- new_width = min(self.max_image_size[0], width)
28
- new_height = int(new_width / aspect)
29
- else:
30
- new_height = min(self.max_image_size[1], height)
31
- new_width = int(new_height * aspect)
32
-
33
- return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
34
 
35
  def analyze_frame(self, frame: np.ndarray) -> str:
36
  if frame is None:
37
  return "No frame received"
38
-
39
  # Convert and resize image
40
  if len(frame.shape) == 2:
41
  frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
@@ -48,9 +47,9 @@ def create_monitor_interface():
48
  # High quality image for better analysis
49
  buffered = io.BytesIO()
50
  frame_pil.save(buffered,
51
- format="JPEG",
52
- quality=95,
53
- optimize=True)
54
  img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
55
  image_url = f"data:image/jpeg;base64,{img_base64}"
56
 
@@ -63,24 +62,24 @@ def create_monitor_interface():
63
  "content": [
64
  {
65
  "type": "text",
66
- "text": """Analyze this workplace image for safety conditions and hazards. Focus only on safety aspects such as:
67
-
68
- 1. Work posture and ergonomics at the shown position
69
- 2. Use of PPE and safety equipment
70
- 3. Tool handling and work techniques
71
- 4. Environmental conditions and surroundings
72
- 5. Equipment and machinery safety
73
- 6. Ground conditions and trip hazards
74
-
75
- Do not identify or describe any individuals. Instead, describe the safety conditions and actions observed.
76
-
77
- Format each safety observation as:
78
- - <location>position:safety condition description</location>
79
-
80
- Examples:
81
- - <location>center:Improper kneeling posture without knee protection, risking joint injury</location>
82
- - <location>left:Heavy machinery operating in close proximity to work area</location>
83
- - <location>bottom:Uneven ground surface creating trip hazard near work zone</location>"""
84
  },
85
  {
86
  "type": "image_url",
@@ -91,15 +90,48 @@ def create_monitor_interface():
91
  ]
92
  }
93
  ],
94
- temperature=0.7,
95
  max_tokens=500,
96
  stream=False
97
  )
98
  return completion.choices[0].message.content
99
  except Exception as e:
100
- print(f"Detailed error: {str(e)}")
101
  return f"Analysis Error: {str(e)}"
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def draw_observations(self, image, observations):
104
  """Draw accurate bounding boxes based on safety issue locations."""
105
  height, width = image.shape[:2]
@@ -110,7 +142,6 @@ def create_monitor_interface():
110
 
111
  def get_region_coordinates(position: str) -> tuple:
112
  """Get coordinates based on position description."""
113
- # Basic regions
114
  regions = {
115
  'center': (width//3, height//3, 2*width//3, 2*height//3),
116
  'background': (0, 0, width, height),
@@ -122,7 +153,9 @@ def create_monitor_interface():
122
  'bottom-left': (0, 2*height//3, width//3, height),
123
  'bottom': (width//3, 2*height//3, 2*width//3, height),
124
  'bottom-right': (2*width//3, 2*height//3, width, height),
125
- 'ground': (0, 2*height//3, width, height)
 
 
126
  }
127
 
128
  # Find best matching region
@@ -131,7 +164,7 @@ def create_monitor_interface():
131
  if key in position:
132
  return regions[key]
133
 
134
- return regions['center'] # Default to center if no match
135
 
136
  for idx, obs in enumerate(observations):
137
  color = self.colors[idx % len(self.colors)]
@@ -152,51 +185,17 @@ def create_monitor_interface():
152
 
153
  # Draw text background
154
  cv2.rectangle(image,
155
- (text_x, text_y - label_size[1] - padding),
156
- (text_x + label_size[0] + padding, text_y),
157
- color, -1)
158
 
159
  # Draw text
160
  cv2.putText(image, label,
161
- (text_x + padding//2, text_y - padding//2),
162
- font, font_scale, (255, 255, 255), thickness)
163
 
164
  return image
165
 
166
- def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
167
- if frame is None:
168
- return None, "No image provided"
169
-
170
- analysis = self.analyze_frame(frame)
171
- display_frame = frame.copy()
172
-
173
- # Parse observations from the formatted response
174
- observations = []
175
- lines = analysis.split('\n')
176
- for line in lines:
177
- # Look for location tags in the line
178
- if '<location>' in line and '</location>' in line:
179
- start = line.find('<location>') + len('<location>')
180
- end = line.find('</location>')
181
- location = line[start:end].strip()
182
-
183
- # Get the description that follows the location tag
184
- desc_start = line.find('</location>') + len('</location>:')
185
- description = line[desc_start:].strip()
186
-
187
- if location and description:
188
- observations.append({
189
- 'location': location,
190
- 'description': description
191
- })
192
-
193
- # Draw observations if we found any
194
- if observations:
195
- annotated_frame = self.draw_observations(display_frame, observations)
196
- return annotated_frame, analysis
197
-
198
- return display_frame, analysis
199
-
200
  # Create the main interface
201
  monitor = SafetyMonitor()
202
 
@@ -225,6 +224,13 @@ def create_monitor_interface():
225
  outputs=[output_image, analysis_text]
226
  )
227
 
 
 
 
 
 
 
 
228
  return demo
229
 
230
  demo = create_monitor_interface()
 
7
  import io
8
  import os
9
  import base64
 
10
 
11
  def create_monitor_interface():
12
  api_key = os.getenv("GROQ_API_KEY")
 
15
  def __init__(self):
16
  self.client = Groq()
17
  self.model_name = "llama-3.2-90b-vision-preview"
18
+ self.max_image_size = (800, 800)
19
+ self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
20
 
21
  def resize_image(self, image):
22
  height, width = image.shape[:2]
23
+ if height > self.max_image_size[1] or width > self.max_image_size[0]:
24
+ aspect = width / height
25
+ if width > height:
26
+ new_width = self.max_image_size[0]
27
+ new_height = int(new_width / aspect)
28
+ else:
29
+ new_height = self.max_image_size[1]
30
+ new_width = int(new_height * aspect)
31
+ return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
32
+ return image
33
 
34
  def analyze_frame(self, frame: np.ndarray) -> str:
35
  if frame is None:
36
  return "No frame received"
37
+
38
  # Convert and resize image
39
  if len(frame.shape) == 2:
40
  frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
 
47
  # High quality image for better analysis
48
  buffered = io.BytesIO()
49
  frame_pil.save(buffered,
50
+ format="JPEG",
51
+ quality=95,
52
+ optimize=True)
53
  img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
54
  image_url = f"data:image/jpeg;base64,{img_base64}"
55
 
 
62
  "content": [
63
  {
64
  "type": "text",
65
+ "text": """Analyze this workplace image for safety conditions and hazards. Focus on:
66
+
67
+ 1. Work posture and ergonomics
68
+ 2. PPE and safety equipment usage
69
+ 3. Tool handling and techniques
70
+ 4. Environmental conditions
71
+ 5. Equipment and machinery safety
72
+ 6. Ground conditions and hazards
73
+
74
+ Describe each safety condition observed, using this exact format:
75
+ - <location>position</location>: detailed safety observation
76
+
77
+ Examples:
78
+ - <location>center</location>: Improper kneeling posture without knee protection, risking joint injury
79
+ - <location>background</location>: Heavy machinery operating in close proximity creating hazard zone
80
+ - <location>ground</location>: Uneven surface and debris creating trip hazards
81
+
82
+ Be specific about locations and safety concerns."""
83
  },
84
  {
85
  "type": "image_url",
 
90
  ]
91
  }
92
  ],
93
+ temperature=0.5,
94
  max_tokens=500,
95
  stream=False
96
  )
97
  return completion.choices[0].message.content
98
  except Exception as e:
99
+ print(f"Analysis error: {str(e)}")
100
  return f"Analysis Error: {str(e)}"
101
 
102
+ def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
103
+ if frame is None:
104
+ return None, "No image provided"
105
+
106
+ analysis = self.analyze_frame(frame)
107
+ display_frame = frame.copy()
108
+
109
+ # Parse observations from the formatted response
110
+ observations = []
111
+ lines = analysis.split('\n')
112
+ for line in lines:
113
+ if '<location>' in line and '</location>' in line:
114
+ start = line.find('<location>') + len('<location>')
115
+ end = line.find('</location>')
116
+ location = line[start:end].strip()
117
+
118
+ # Get the description that follows the location tags
119
+ desc_start = line.find('</location>') + len('</location>:')
120
+ description = line[desc_start:].strip()
121
+
122
+ if location and description:
123
+ observations.append({
124
+ 'location': location,
125
+ 'description': description
126
+ })
127
+
128
+ # Draw observations if we found any
129
+ if observations:
130
+ annotated_frame = self.draw_observations(display_frame, observations)
131
+ return annotated_frame, analysis
132
+
133
+ return display_frame, analysis
134
+
135
  def draw_observations(self, image, observations):
136
  """Draw accurate bounding boxes based on safety issue locations."""
137
  height, width = image.shape[:2]
 
142
 
143
  def get_region_coordinates(position: str) -> tuple:
144
  """Get coordinates based on position description."""
 
145
  regions = {
146
  'center': (width//3, height//3, 2*width//3, 2*height//3),
147
  'background': (0, 0, width, height),
 
153
  'bottom-left': (0, 2*height//3, width//3, height),
154
  'bottom': (width//3, 2*height//3, 2*width//3, height),
155
  'bottom-right': (2*width//3, 2*height//3, width, height),
156
+ 'ground': (0, 2*height//3, width, height),
157
+ 'machinery': (0, 0, width//2, height),
158
+ 'work-area': (width//4, height//4, 3*width//4, 3*height//4)
159
  }
160
 
161
  # Find best matching region
 
164
  if key in position:
165
  return regions[key]
166
 
167
+ return regions['center']
168
 
169
  for idx, obs in enumerate(observations):
170
  color = self.colors[idx % len(self.colors)]
 
185
 
186
  # Draw text background
187
  cv2.rectangle(image,
188
+ (text_x, text_y - label_size[1] - padding),
189
+ (text_x + label_size[0] + padding, text_y),
190
+ color, -1)
191
 
192
  # Draw text
193
  cv2.putText(image, label,
194
+ (text_x + padding//2, text_y - padding//2),
195
+ font, font_scale, (255, 255, 255), thickness)
196
 
197
  return image
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  # Create the main interface
200
  monitor = SafetyMonitor()
201
 
 
224
  outputs=[output_image, analysis_text]
225
  )
226
 
227
+ gr.Markdown("""
228
+ ## Instructions:
229
+ 1. Upload an image to analyze safety conditions
230
+ 2. View annotated results showing safety concerns
231
+ 3. Read detailed analysis of identified issues
232
+ """)
233
+
234
  return demo
235
 
236
  demo = create_monitor_interface()