File size: 7,137 Bytes
7b04d4e
 
 
 
64b8a91
49a323c
7b04d4e
32acaa8
64b8a91
 
a5f647b
771e08a
5b41d95
1cddd79
5b41d95
 
 
 
64b8a91
 
 
5b41d95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64b8a91
5b41d95
 
 
e9cc5f1
5b41d95
 
 
 
 
 
 
 
 
 
 
 
 
64b8a91
 
 
5b41d95
 
 
 
 
 
 
 
64b8a91
 
 
 
5b41d95
 
64b8a91
312a2f2
64b8a91
 
 
5b41d95
 
 
64b8a91
5b41d95
 
 
 
 
 
 
 
64b8a91
5b41d95
 
 
64b8a91
 
 
 
 
5b41d95
 
64b8a91
5b41d95
64b8a91
 
5b41d95
64b8a91
 
5b41d95
 
 
 
 
 
 
 
 
 
64b8a91
5b41d95
 
 
 
64b8a91
5b41d95
 
 
64b8a91
 
 
 
 
5b41d95
64b8a91
 
5b41d95
64b8a91
5b41d95
64b8a91
5b41d95
 
 
 
d2c67f3
95ca446
5b41d95
64b8a91
1cddd79
5b41d95
 
b4f3ea6
b6ce847
5b41d95
27eab0f
 
 
 
5b41d95
 
33fd6ad
5b41d95
b4f3ea6
 
 
1cddd79
7b04d4e
1cddd79
 
64b8a91
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import gradio as gr
import cv2
import numpy as np
from groq import Groq
import time
from PIL import Image as PILImage
import io
import os
import base64
import random

def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")
    
    class SafetyMonitor:
        def __init__(self):
            self.client = Groq()
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)  # Increased size for better visibility
            self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
            
        def resize_image(self, image):
            height, width = image.shape[:2]
            aspect = width / height
            
            if width > height:
                new_width = min(self.max_image_size[0], width)
                new_height = int(new_width / aspect)
            else:
                new_height = min(self.max_image_size[1], height)
                new_width = int(new_height * aspect)
                
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        def analyze_frame(self, frame: np.ndarray) -> str:
            if frame is None:
                return "No frame received"
                
            # Convert and resize image
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            
            frame = self.resize_image(frame)
            frame_pil = PILImage.fromarray(frame)
            
            # Convert to base64 with minimal quality
            buffered = io.BytesIO()
            frame_pil.save(buffered, 
                         format="JPEG", 
                         quality=50, 
                         optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"
            
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this workplace image and describe each safety concern in this format:
                                    - <location>Description</location>
                                    Use one line per issue, starting with a dash and location in tags."""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        },
                        {
                            "role": "assistant",
                            "content": ""
                        }
                    ],
                    temperature=0.1,
                    max_tokens=500,
                    top_p=1,
                    stream=False,
                    stop=None
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Detailed error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def draw_observations(self, image, observations):
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            
            # Generate random positions for each observation
            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                
                # Generate random box position
                box_width = width // 3
                box_height = height // 3
                x = random.randint(0, width - box_width)
                y = random.randint(0, height - box_height)
                
                # Draw rectangle
                cv2.rectangle(image, (x, y), (x + box_width, y + box_height), color, 2)
                
                # Add label with background
                label = obs[:40] + "..." if len(obs) > 40 else obs
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                cv2.rectangle(image, (x, y - 20), (x + label_size[0], y), color, -1)
                cv2.putText(image, label, (x, y - 5), font, font_scale, (255, 255, 255), thickness)
            
            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            if frame is None:
                return None, "No image provided"
                
            analysis = self.analyze_frame(frame)
            display_frame = self.resize_image(frame.copy())
            
            # Parse observations from the analysis
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-'):
                    # Extract text between <location> tags if present
                    if '<location>' in line and '</location>' in line:
                        start = line.find('<location>') + len('<location>')
                        end = line.find('</location>')
                        observation = line[end + len('</location>'):].strip()
                    else:
                        observation = line[1:].strip()  # Remove the dash
                    if observation:
                        observations.append(observation)
            
            # Draw observations on the image
            annotated_frame = self.draw_observations(display_frame, observations)
            
            return annotated_frame, analysis

    # Create the main interface
    monitor = SafetyMonitor()
    
    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Annotated Results")
        
        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
            
        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"
            
        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

    return demo

# Build the app at import time so `demo` stays importable (e.g. by Gradio's
# reload mode), but only start the server when run as a script.
demo = create_monitor_interface()

if __name__ == "__main__":
    demo.launch()