capradeepgujaran commited on
Commit
33fd6ad
·
verified ·
1 Parent(s): 670756a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -97
app.py CHANGED
@@ -5,27 +5,33 @@ from groq import Groq
5
  import time
6
  from PIL import Image
7
  import io
8
- from typing import Optional
 
 
 
 
9
 
10
  class SafetyMonitor:
11
- def __init__(self, api_key: str, model_name: str = "mixtral-8x7b-vision"):
12
  """
13
- Initialize the safety monitor with configurable model
14
-
15
- Args:
16
- api_key (str): Groq API key
17
- model_name (str): Name of the vision model to use
18
  """
 
 
 
 
19
  self.client = Groq(api_key=api_key)
20
  self.model_name = model_name
21
- self.analysis_interval = 2 # seconds
22
 
23
  def analyze_frame(self, frame: np.ndarray) -> str:
24
  """
25
  Analyze a single frame using specified vision model
26
  """
 
 
 
27
  # Convert frame to PIL Image
28
- frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
29
 
30
  # Convert image to bytes
31
  img_byte_arr = io.BytesIO()
@@ -38,7 +44,6 @@ class SafetyMonitor:
38
  2. Unsafe behaviors or positions
39
  3. Equipment and machinery safety
40
  4. Environmental hazards (spills, obstacles, poor lighting)
41
- 5. Emergency exit accessibility
42
 
43
  Provide specific observations and any immediate safety concerns."""
44
 
@@ -55,107 +60,83 @@ class SafetyMonitor:
55
  ],
56
  model=self.model_name,
57
  max_tokens=200,
58
- temperature=0.2 # Lower temperature for more focused safety analysis
59
  )
60
  return completion.choices[0].message.content
61
  except Exception as e:
62
  return f"Analysis Error: {str(e)}"
63
 
64
- def process_video_stream(self):
65
  """
66
- Process video stream and yield analyzed frames
67
  """
68
- cap = cv2.VideoCapture(0) # Use 0 for webcam
69
- last_analysis_time = 0
70
- latest_analysis = "Initializing safety analysis..."
71
-
72
- while cap.isOpened():
73
- ret, frame = cap.read()
74
- if not ret:
75
- break
76
-
77
- current_time = time.time()
78
 
79
- # Perform analysis at specified intervals
80
- if current_time - last_analysis_time >= self.analysis_interval:
81
- latest_analysis = self.analyze_frame(frame)
82
- last_analysis_time = current_time
83
-
84
- # Create a copy of frame for visualization
85
- display_frame = frame.copy()
86
-
87
- # Add semi-transparent overlay for text background
88
- overlay = display_frame.copy()
89
- cv2.rectangle(overlay, (5, 5), (640, 200), (0, 0, 0), -1)
90
- cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)
91
-
92
- # Add analysis text
93
- cv2.putText(display_frame, "Safety Analysis:", (10, 30),
94
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
95
-
96
- # Split and display analysis text
97
- y_position = 60
98
- for line in latest_analysis.split('\n'):
99
- cv2.putText(display_frame, line[:80], (10, y_position),
100
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
101
- y_position += 30
102
-
103
- yield display_frame
104
 
105
- cap.release()
106
 
107
- def create_gradio_interface(monitor: SafetyMonitor):
108
  """
109
- Create and launch the Gradio interface
110
  """
111
- with gr.Blocks() as demo:
112
- gr.Markdown(f"""
113
- # Real-time Safety Monitoring System
114
- Using model: {monitor.model_name}
115
- """)
116
-
117
- with gr.Row():
118
- video_output = gr.Image(label="Live Feed with Safety Analysis")
119
 
120
- with gr.Row():
121
- start_button = gr.Button("Start Monitoring", variant="primary")
122
- stop_button = gr.Button("Stop")
 
 
123
 
124
- with gr.Row():
125
- interval_slider = gr.Slider(
126
- minimum=1,
127
- maximum=10,
128
- value=monitor.analysis_interval,
129
- step=0.5,
130
- label="Analysis Interval (seconds)"
 
 
 
 
 
 
 
 
 
 
 
 
131
  )
132
 
133
- def update_interval(value):
134
- monitor.analysis_interval = value
135
- return gr.update()
136
-
137
- def start_monitoring():
138
- return gr.Image.update(value=monitor.process_video_stream())
139
-
140
- start_button.click(fn=start_monitoring, outputs=[video_output])
141
- stop_button.click(fn=lambda: None, outputs=[video_output])
142
- interval_slider.change(fn=update_interval, inputs=[interval_slider])
143
-
144
- demo.launch(share=True)
145
-
146
- def main():
147
- # Replace with your actual API key
148
- GROQ_API_KEY = "YOUR_GROQ_API_KEY"
149
-
150
- # Initialize the safety monitor with desired model
151
- monitor = SafetyMonitor(
152
- api_key=GROQ_API_KEY,
153
- model_name="mixtral-8x7b-vision" # Replace with your preferred model
154
- )
155
-
156
- # Launch the Gradio interface
157
- create_gradio_interface(monitor)
158
-
159
  if __name__ == "__main__":
160
- main()
161
-
 
5
  import time
6
  from PIL import Image
7
  import io
8
+ import os
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables
12
+ load_dotenv()
13
 
14
  class SafetyMonitor:
15
+ def __init__(self, model_name: str = "mixtral-8x7b-vision"):
16
  """
17
+ Initialize the safety monitor using environment variables for API key
 
 
 
 
18
  """
19
+ api_key = os.getenv("GROQ_API_KEY")
20
+ if not api_key:
21
+ raise ValueError("GROQ_API_KEY environment variable is not set")
22
+
23
  self.client = Groq(api_key=api_key)
24
  self.model_name = model_name
 
25
 
26
  def analyze_frame(self, frame: np.ndarray) -> str:
27
  """
28
  Analyze a single frame using specified vision model
29
  """
30
+ if frame is None:
31
+ return "No frame received"
32
+
33
  # Convert frame to PIL Image
34
+ frame_pil = Image.fromarray(frame)
35
 
36
  # Convert image to bytes
37
  img_byte_arr = io.BytesIO()
 
44
  2. Unsafe behaviors or positions
45
  3. Equipment and machinery safety
46
  4. Environmental hazards (spills, obstacles, poor lighting)
 
47
 
48
  Provide specific observations and any immediate safety concerns."""
49
 
 
60
  ],
61
  model=self.model_name,
62
  max_tokens=200,
63
+ temperature=0.2
64
  )
65
  return completion.choices[0].message.content
66
  except Exception as e:
67
  return f"Analysis Error: {str(e)}"
68
 
69
+ def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
70
  """
71
+ Process and analyze a single frame
72
  """
73
+ if frame is None:
74
+ return None, "No frame received"
 
 
 
 
 
 
 
 
75
 
76
+ analysis = self.analyze_frame(frame)
77
+
78
+ # Create a copy of frame for visualization
79
+ display_frame = frame.copy()
80
+
81
+ # Add semi-transparent overlay for text background
82
+ overlay = display_frame.copy()
83
+ cv2.rectangle(overlay, (5, 5), (640, 200), (0, 0, 0), -1)
84
+ cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)
85
+
86
+ # Add analysis text
87
+ cv2.putText(display_frame, "Safety Analysis:", (10, 30),
88
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
89
+
90
+ # Split and display analysis text
91
+ y_position = 60
92
+ for line in analysis.split('\n'):
93
+ cv2.putText(display_frame, line[:80], (10, y_position),
94
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
95
+ y_position += 30
 
 
 
 
 
96
 
97
+ return display_frame, analysis
98
 
99
+ def create_gradio_interface():
100
  """
101
+ Create and launch the Gradio interface with webcam input
102
  """
103
+ try:
104
+ # Initialize the safety monitor
105
+ monitor = SafetyMonitor(model_name="mixtral-8x7b-vision")
 
 
 
 
 
106
 
107
+ with gr.Blocks() as demo:
108
+ gr.Markdown("""
109
+ # Real-time Safety Monitoring System
110
+ Click 'Start Webcam' to begin monitoring.
111
+ """)
112
 
113
+ with gr.Row():
114
+ # Webcam input
115
+ webcam = gr.Image(source="webcam", streaming=True, label="Webcam Feed")
116
+ # Analysis output
117
+ output_image = gr.Image(label="Analyzed Feed")
118
+
119
+ with gr.Row():
120
+ analysis_text = gr.Textbox(label="Safety Analysis", lines=5)
121
+
122
+ def analyze_stream(frame):
123
+ if frame is None:
124
+ return None, "Webcam not started"
125
+ processed_frame, analysis = monitor.process_frame(frame)
126
+ return processed_frame, analysis
127
+
128
+ webcam.stream(
129
+ fn=analyze_stream,
130
+ outputs=[output_image, analysis_text],
131
+ show_progress="hidden"
132
  )
133
 
134
+ demo.queue()
135
+ demo.launch()
136
+
137
+ except ValueError as e:
138
+ print(f"Error: {e}")
139
+ print("Please make sure to set the GROQ_API_KEY environment variable")
140
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  if __name__ == "__main__":
142
+ create_gradio_interface()