import base64
import io

import cv2
import gradio as gr
from groq import Groq
from PIL import Image as PILImage
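
# Workplace safety analyzer: sends an uploaded image to a Groq-hosted
# Llama 3.2 vision model, parses the hazards it reports, and draws
# annotated boxes over the matching regions of the image.
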
class SafetyMonitor:
def __init__(self):
"""Initialize Safety Monitor with configuration."""
self.client = Groq()
self.model_name = "llama-3.2-90b-vision-preview"
self.max_image_size = (800, 800)
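        # Box/label colors, cycled per observation; note Gradio supplies RGB frames.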
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
def preprocess_image(self, frame):
"""Process image for analysis."""
if len(frame.shape) == 2:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
elif len(frame.shape) == 3 and frame.shape[2] == 4:
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
return self.resize_image(frame)
def resize_image(self, image):
"""Resize image while maintaining aspect ratio."""
height, width = image.shape[:2]
if height > self.max_image_size[1] or width > self.max_image_size[0]:
aspect = width / height
if width > height:
new_width = self.max_image_size[0]
new_height = int(new_width / aspect)
else:
new_height = self.max_image_size[1]
new_width = int(new_height * aspect)
return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
return image
def encode_image(self, frame):
"""Convert image to base64 encoding."""
frame_pil = PILImage.fromarray(frame)
buffered = io.BytesIO()
frame_pil.save(buffered, format="JPEG", quality=95)
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
return f"data:image/jpeg;base64,{img_base64}"
def get_scene_context(self, image):
"""Analyze the scene context."""
try:
image_url = self.encode_image(image)
completion = self.client.chat.completions.create(
model=self.model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": """Analyze this workplace image and identify key areas and elements. Include:
1. Worker locations and activities
2. Equipment and machinery
3. Materials and storage
4. Access routes and paths
5. Hazardous areas
Format each observation as:
- Element: specific location in image"""
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
}
],
temperature=0.3,
max_tokens=200,
stream=False
)
return completion.choices[0].message.content
except Exception as e:
print(f"Scene analysis error: {str(e)}")
return ""
def analyze_frame(self, frame):
"""Perform safety analysis on the frame."""
if frame is None:
return "No frame received", {}
frame = self.preprocess_image(frame)
image_url = self.encode_image(frame)
try:
completion = self.client.chat.completions.create(
model=self.model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": """Analyze this image for safety hazards. For each hazard:
1. Specify the precise location in the image
2. Describe the safety concern or violation
3. Indicate the potential risk
Format each finding as:
- <location>position:detailed safety concern</location>
Look for all types of safety issues:
- PPE compliance
- Ergonomic risks
- Equipment safety
- Environmental hazards
- Material handling
- Work procedures
- Access and egress
- Housekeeping"""
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
}
],
temperature=0.5,
max_tokens=500,
stream=False
)
return completion.choices[0].message.content, {}
except Exception as e:
print(f"Analysis error: {str(e)}")
return f"Analysis Error: {str(e)}", {}
def get_region_coordinates(self, position, image_shape):
"""Convert textual position to coordinates."""
height, width = image_shape[:2]
# Parse position for spatial information
position = position.lower()
# Base coordinates (full image)
x1, y1, x2, y2 = 0, 0, width, height
# Define regions
regions = {
'center': (width//3, height//3, 2*width//3, 2*height//3),
'top': (width//3, 0, 2*width//3, height//3),
'bottom': (width//3, 2*height//3, 2*width//3, height),
'left': (0, height//3, width//3, 2*height//3),
'right': (2*width//3, height//3, width, 2*height//3),
'top-left': (0, 0, width//3, height//3),
'top-right': (2*width//3, 0, width, height//3),
'bottom-left': (0, 2*height//3, width//3, height),
'bottom-right': (2*width//3, 2*height//3, width, height),
'upper': (0, 0, width, height//2),
'lower': (0, height//2, width, height),
'middle': (0, height//3, width, 2*height//3)
}
# Find best matching region
best_match = None
max_match = 0
for region, coords in regions.items():
if region in position:
words = region.split('-')
matches = sum(1 for word in words if word in position)
if matches > max_match:
max_match = matches
best_match = coords
return best_match if best_match else (x1, y1, x2, y2)
def draw_observations(self, image, observations):
"""Draw bounding boxes and labels for safety observations."""
height, width = image.shape[:2]
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
thickness = 2
padding = 10
for idx, obs in enumerate(observations):
color = self.colors[idx % len(self.colors)]
# Get coordinates for this observation
x1, y1, x2, y2 = self.get_region_coordinates(obs['location'], image.shape)
# Draw rectangle
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
# Add label with background
label = obs['description'][:50] + "..." if len(obs['description']) > 50 else obs['description']
label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
# Position text above the box
text_x = max(0, x1)
text_y = max(label_size[1] + padding, y1 - padding)
# Draw text background
cv2.rectangle(image,
(text_x, text_y - label_size[1] - padding),
(text_x + label_size[0] + padding, text_y),
color, -1)
# Draw text
cv2.putText(image, label,
(text_x + padding//2, text_y - padding//2),
font, font_scale, (255, 255, 255), thickness)
return image
def process_frame(self, frame):
"""Main processing pipeline for safety analysis."""
if frame is None:
return None, "No image provided"
try:
# Get analysis
analysis, _ = self.analyze_frame(frame)
display_frame = frame.copy()
# Parse observations
observations = []
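            # Expected finding format, one per line, e.g.:
            #   - <location>top-left:worker without a hard hat</location>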
for line in analysis.split('\n'):
line = line.strip()
                if line.startswith('-') and '<location>' in line and '</location>' in line:
                    start = line.find('<location>') + len('<location>')
                    end = line.find('</location>')
location_description = line[start:end].strip()
if ':' in location_description:
location, description = location_description.split(':', 1)
observations.append({
'location': location.strip(),
'description': description.strip()
})
# Draw observations
if observations:
annotated_frame = self.draw_observations(display_frame, observations)
return annotated_frame, analysis
return display_frame, analysis
except Exception as e:
print(f"Processing error: {str(e)}")
return None, f"Error processing image: {str(e)}"
def create_monitor_interface():
monitor = SafetyMonitor()
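    # A single monitor instance is shared across all requests.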
with gr.Blocks() as demo:
gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
with gr.Row():
input_image = gr.Image(label="Upload Image")
output_image = gr.Image(label="Safety Analysis")
analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
def analyze_image(image):
if image is None:
return None, "No image provided"
try:
processed_frame, analysis = monitor.process_frame(image)
return processed_frame, analysis
except Exception as e:
print(f"Processing error: {str(e)}")
return None, f"Error processing image: {str(e)}"
input_image.change(
fn=analyze_image,
inputs=input_image,
outputs=[output_image, analysis_text]
)
gr.Markdown("""
## Instructions:
1. Upload any workplace/safety-related image
2. View identified hazards and their locations
3. Read detailed analysis of safety concerns
""")
return demo
if __name__ == "__main__":
demo = create_monitor_interface()
demo.launch()