eusholli committed
Commit 6781da9
1 Parent(s): 7768db8

Changed to streamlit-webrtc object detection
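For orientation: the commit replaces the previous cv2.VideoCapture polling loop with streamlit-webrtc's callback model, in which the browser streams frames over WebRTC and a Python callback processes each one. A minimal sketch of that pattern follows (illustrative only, not part of this commit; the flip step is a placeholder for any per-frame processing):

import av
import cv2
import streamlit as st
from streamlit_webrtc import webrtc_streamer


def flip_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Called on a worker thread for every frame received from the browser.
    img = frame.to_ndarray(format="bgr24")
    img = cv2.flip(img, 1)  # placeholder per-frame processing
    return av.VideoFrame.from_ndarray(img, format="bgr24")


st.title("streamlit-webrtc callback sketch")
webrtc_streamer(key="callback-sketch", video_frame_callback=flip_callback)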
app.py CHANGED
@@ -1,146 +1,163 @@
-import os
-# os.environ['OPENCV_AVFOUNDATION_SKIP_AUTH'] = '1'
-
-import streamlit as st
+"""Object detection demo with MobileNet SSD.
+This model and code are based on
+https://github.com/robmarkcole/object-detection-app
+"""
+
+import logging
+import queue
+from pathlib import Path
+from typing import List, NamedTuple
+
+import av
 import cv2
 import numpy as np
-from transformers import pipeline
-from PIL import Image, ImageDraw
-from mtcnn import MTCNN
-
-# Initialize the Hugging Face pipeline for facial emotion detection
-emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")
-
-# Initialize MTCNN for face detection
-mtcnn = MTCNN()
-
-# Function to analyze sentiment
-def analyze_sentiment(face):
-    # Convert face to RGB
-    rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
-    # Convert the face to a PIL image
-    pil_image = Image.fromarray(rgb_face)
-    # Analyze sentiment using the Hugging Face pipeline
-    results = emotion_pipeline(pil_image)
-    # Get the dominant emotion
-    dominant_emotion = max(results, key=lambda x: x['score'])['label']
-    return dominant_emotion
-
-TEXT_SIZE = 3
-
-# Function to detect faces, analyze sentiment, and draw a red box around them
-def detect_and_draw_faces(frame):
-    # Detect faces using MTCNN
-    results = mtcnn.detect_faces(frame)
-
-    # Draw on the frame
-    for result in results:
-        x, y, w, h = result['box']
-        face = frame[y:y+h, x:x+w]
-        sentiment = analyze_sentiment(face)
-        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), 10)  # Thicker red box
-
-        # Calculate position for the text background and the text itself
-        text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
-        text_x = x
-        text_y = y - 10
-        background_tl = (text_x, text_y - text_size[1])
-        background_br = (text_x + text_size[0], text_y + 5)
-
-        # Draw black rectangle as background
-        cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
-        # Draw white text on top
-        cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)
-
-    return frame
-
-# Function to capture video from webcam
-def video_stream():
-    video_capture = cv2.VideoCapture(0)
-    if not video_capture.isOpened():
-        st.error("Error: Could not open video capture device.")
-        return
-
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-        yield frame
-
-    video_capture.release()
-
-# Streamlit UI
-st.markdown(
-    """
-    <style>
-    .main {
-        background-color: #FFFFFF;
-    }
-    .reportview-container .main .block-container{
-        padding-top: 2rem;
-    }
-    h1 {
-        color: #E60012;
-        font-family: 'Arial Black', Gadget, sans-serif;
-    }
-    h2 {
-        color: #E60012;
-        font-family: 'Arial', sans-serif;
-    }
-    h3 {
-        color: #333333;
-        font-family: 'Arial', sans-serif;
-    }
-    .stButton button {
-        background-color: #E60012;
-        color: white;
-        border-radius: 5px;
-        font-size: 16px;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True
-)
-
-st.title("Computer Vision Test Lab")
-st.subheader("Facial Sentiment")
-
-# Columns for input and output streams
-col1, col2 = st.columns(2)
-
-with col1:
-    st.header("Input Stream")
-    st.subheader("Webcam")
-    video_placeholder = st.empty()
-
-with col2:
-    st.header("Output Stream")
-    st.subheader("Analysis")
-    output_placeholder = st.empty()
-
-sentiment_placeholder = st.empty()
-
-# Start video stream
-video_capture = cv2.VideoCapture(0)
-if not video_capture.isOpened():
-    st.error("Error: Could not open video capture device.")
+import streamlit as st
+from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+from utils.download import download_file
+from utils.turn import get_ice_servers
+
+HERE = Path(__file__).parent
+ROOT = HERE.parent
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
+MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
+PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
+PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
+
+CLASSES = [
+    "background",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+@st.cache_resource  # type: ignore
+def generate_label_colors():
+    return np.random.uniform(0, 255, size=(len(CLASSES), 3))
+
+
+COLORS = generate_label_colors()
+
+download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
+download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
+
+
+# Session-specific caching
+cache_key = "object_detection_dnn"
+if cache_key in st.session_state:
+    net = st.session_state[cache_key]
 else:
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-
-        # Display the input stream with the red box around the face
-        video_placeholder.image(frame, channels="BGR")
-
-        # Detect faces, analyze sentiment, and draw red boxes with sentiment labels
-        frame_with_boxes = detect_and_draw_faces(frame)
-
-        # Display the output stream (here it's the same as input, modify as needed)
-        output_placeholder.image(frame_with_boxes, channels="BGR")
-
-        # Add a short delay to control the frame rate
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
+    net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
+    st.session_state[cache_key] = net
+
+score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
+
+# NOTE: The callback will be called in another thread,
+# so use a queue here for thread-safety to pass the data
+# from inside to outside the callback.
+# TODO: A general-purpose shared state object may be more useful.
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
+
+
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+    image = frame.to_ndarray(format="bgr24")
+
+    # Run inference
+    blob = cv2.dnn.blobFromImage(
+        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
+    )
+    net.setInput(blob)
+    output = net.forward()
+
+    h, w = image.shape[:2]
+
+    # Convert the output array into a structured form.
+    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+    output = output[output[:, 2] >= score_threshold]
+    detections = [
+        Detection(
+            class_id=int(detection[1]),
+            label=CLASSES[int(detection[1])],
+            score=float(detection[2]),
+            box=(detection[3:7] * np.array([w, h, w, h])),
+        )
+        for detection in output
+    ]
+
+    # Render bounding boxes and captions
+    for detection in detections:
+        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+        color = COLORS[detection.class_id]
+        xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.putText(
+            image,
+            caption,
+            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
+            2,
+        )
+
+    result_queue.put(detections)
+
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
+
+
+webrtc_ctx = webrtc_streamer(
+    key="object-detection",
+    mode=WebRtcMode.SENDRECV,
+    rtc_configuration={"iceServers": get_ice_servers()},
+    video_frame_callback=video_frame_callback,
+    media_stream_constraints={"video": True, "audio": False},
+    async_processing=True,
+)
+
+if st.checkbox("Show the detected labels", value=True):
+    if webrtc_ctx.state.playing:
+        labels_placeholder = st.empty()
+        # NOTE: The video transformation with object detection and
+        # this loop displaying the result labels are running
+        # in different threads asynchronously.
+        # Then the rendered video frames and the labels displayed here
+        # are not strictly synchronized.
+        while True:
+            result = result_queue.get()
+            labels_placeholder.table(result)
+
+st.markdown(
+    "This demo uses a model and code from "
+    "https://github.com/robmarkcole/object-detection-app. "
+    "Many thanks to the project."
+)
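For readers unfamiliar with the tensor that video_frame_callback decodes above: each row of the squeezed (N, 7) output is conventionally [image_id, class_id, score, xmin, ymin, xmax, ymax] with coordinates normalized to [0, 1]. That layout is stated here as an assumption for clarity; the diff itself does not spell it out. A small self-contained sketch with a fake network output (not part of the commit):

import numpy as np

score_threshold = 0.5
h, w = 480, 640  # frame size that boxes are scaled back to

# Fake net.forward() result shaped (1, 1, N, 7): two candidate detections.
fake_output = np.array([[[
    [0.0, 15.0, 0.93, 0.10, 0.20, 0.45, 0.90],  # class 15 ("person"), 93%
    [0.0, 7.0, 0.30, 0.50, 0.55, 0.80, 0.95],   # class 7 ("car"), below threshold
]]])

rows = fake_output.squeeze()                # (1, 1, N, 7) -> (N, 7)
rows = rows[rows[:, 2] >= score_threshold]  # keep confident detections only
for row in rows:
    class_id, score = int(row[1]), float(row[2])
    box = row[3:7] * np.array([w, h, w, h])  # scale to pixel coordinates
    print(class_id, round(score, 2), box.astype(int))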
app.py.sentiment-one DELETED
@@ -1,118 +0,0 @@
-import os
-os.environ['OPENCV_AVFOUNDATION_SKIP_AUTH'] = '1'
-
-import streamlit as st
-import cv2
-from transformers import pipeline
-from PIL import Image
-
-# Initialize the Hugging Face pipeline for facial emotion detection
-emotion_pipeline = pipeline("image-classification", model="dima806/facial_emotions_image_detection")
-
-# Function to analyze sentiment
-def analyze_sentiment(frame):
-    # Convert frame to RGB
-    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    # Convert the frame to a PIL image
-    pil_image = Image.fromarray(rgb_frame)
-    # Analyze sentiment using the Hugging Face pipeline
-    results = emotion_pipeline(pil_image)  # Analyze sentiment using the Hugging Face pipeline
-    results = emotion_pipeline(pil_image)
-    # Get the dominant emotion
-    dominant_emotion = max(results, key=lambda x: x['score'])['label']
-    return dominant_emotion
-
-# Function to capture video from webcam
-def video_stream():
-    video_capture = cv2.VideoCapture(0)
-    if not video_capture.isOpened():
-        st.error("Error: Could not open video capture device.")
-        return
-
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-        yield frame
-
-    video_capture.release()
-
-# Streamlit UI
-st.markdown(
-    """
-    <style>
-    .main {
-        background-color: #FFFFFF;
-    }
-    .reportview-container .main .block-container{
-        padding-top: 2rem;
-    }
-    h1 {
-        color: #E60012;
-        font-family: 'Arial Black', Gadget, sans-serif;
-    }
-    h2 {
-        color: #E60012;
-        font-family: 'Arial', sans-serif;
-    }
-    h3 {
-        color: #333333;
-        font-family: 'Arial', sans-serif;
-    }
-    .stButton button {
-        background-color: #E60012;
-        color: white;
-        border-radius: 5px;
-        font-size: 16px;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True
-)
-
-st.title("Computer Vision Test Lab")
-st.subheader("Facial Sentiment")
-
-# Columns for input and output streams
-col1, col2 = st.columns(2)
-
-with col1:
-    st.header("Input Stream")
-    st.subheader("Webcam")
-    video_placeholder = st.empty()
-
-with col2:
-    st.header("Output Stream")
-    st.subheader("Analysis")
-    output_placeholder = st.empty()
-
-sentiment_placeholder = st.empty()
-
-# Start video stream
-video_capture = cv2.VideoCapture(0)
-if not video_capture.isOpened():
-    st.error("Error: Could not open video capture device.")
-else:
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-
-        # Display the input stream
-        video_placeholder.image(frame, channels="BGR")
-
-        # Analyze sentiment
-        sentiment = analyze_sentiment(frame)
-
-        # Display the output stream (here it's the same as input, modify as needed)
-        output_placeholder.image(frame, channels="BGR")
-
-        # Display sentiment
-        sentiment_placeholder.write(f"Sentiment: {sentiment}")
-
-        # Add a short delay to control the frame rate
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-
object_detection.py ADDED
@@ -0,0 +1,163 @@
+"""Object detection demo with MobileNet SSD.
+This model and code are based on
+https://github.com/robmarkcole/object-detection-app
+"""
+
+import logging
+import queue
+from pathlib import Path
+from typing import List, NamedTuple
+
+import av
+import cv2
+import numpy as np
+import streamlit as st
+from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+from utils.download import download_file
+from utils.turn import get_ice_servers
+
+HERE = Path(__file__).parent
+ROOT = HERE.parent
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
+MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
+PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
+PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
+
+CLASSES = [
+    "background",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+@st.cache_resource  # type: ignore
+def generate_label_colors():
+    return np.random.uniform(0, 255, size=(len(CLASSES), 3))
+
+
+COLORS = generate_label_colors()
+
+download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
+download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
+
+
+# Session-specific caching
+cache_key = "object_detection_dnn"
+if cache_key in st.session_state:
+    net = st.session_state[cache_key]
+else:
+    net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
+    st.session_state[cache_key] = net
+
+score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
+
+# NOTE: The callback will be called in another thread,
+# so use a queue here for thread-safety to pass the data
+# from inside to outside the callback.
+# TODO: A general-purpose shared state object may be more useful.
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
+
+
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+    image = frame.to_ndarray(format="bgr24")
+
+    # Run inference
+    blob = cv2.dnn.blobFromImage(
+        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
+    )
+    net.setInput(blob)
+    output = net.forward()
+
+    h, w = image.shape[:2]
+
+    # Convert the output array into a structured form.
+    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+    output = output[output[:, 2] >= score_threshold]
+    detections = [
+        Detection(
+            class_id=int(detection[1]),
+            label=CLASSES[int(detection[1])],
+            score=float(detection[2]),
+            box=(detection[3:7] * np.array([w, h, w, h])),
+        )
+        for detection in output
+    ]
+
+    # Render bounding boxes and captions
+    for detection in detections:
+        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+        color = COLORS[detection.class_id]
+        xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.putText(
+            image,
+            caption,
+            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
+            2,
+        )
+
+    result_queue.put(detections)
+
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
+
+
+webrtc_ctx = webrtc_streamer(
+    key="object-detection",
+    mode=WebRtcMode.SENDRECV,
+    rtc_configuration={"iceServers": get_ice_servers()},
+    video_frame_callback=video_frame_callback,
+    media_stream_constraints={"video": True, "audio": False},
+    async_processing=True,
+)
+
+if st.checkbox("Show the detected labels", value=True):
+    if webrtc_ctx.state.playing:
+        labels_placeholder = st.empty()
+        # NOTE: The video transformation with object detection and
+        # this loop displaying the result labels are running
+        # in different threads asynchronously.
+        # Then the rendered video frames and the labels displayed here
+        # are not strictly synchronized.
+        while True:
+            result = result_queue.get()
+            labels_placeholder.table(result)
+
+st.markdown(
+    "This demo uses a model and code from "
+    "https://github.com/robmarkcole/object-detection-app. "
+    "Many thanks to the project."
+)
requirements.txt CHANGED
@@ -7,3 +7,4 @@ mtcnn
 setuptools
 tensorflow
 tf-keras
+streamlit_webrtc
run_streamlist.sh DELETED
@@ -1,5 +0,0 @@
-#!/bin/bash
-# Set Chrome as the default browser for this session
-export BROWSER="/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
-# Run Streamlit with the provided arguments
-streamlit run "$@"
app.py.safe → sentiment.py RENAMED
@@ -1,15 +1,27 @@
-import os
-os.environ['OPENCV_AVFOUNDATION_SKIP_AUTH'] = '1'
+import threading
 
 import streamlit as st
 import cv2
 import numpy as np
 from transformers import pipeline
 from PIL import Image, ImageDraw
+from mtcnn import MTCNN
+from streamlit_webrtc import webrtc_streamer
+import logging
 
-# Initialize the Hugging Face pipeline for facial emotion detection using the "trpakov/vit-face-expression" model
+# Suppress transformers progress bars
+logging.getLogger("transformers").setLevel(logging.ERROR)
+
+lock = threading.Lock()
+img_container = {"webcam": None,
+                 "analyzed": None}
+
+# Initialize the Hugging Face pipeline for facial emotion detection
 emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")
 
+# Initialize MTCNN for face detection
+mtcnn = MTCNN()
+
 # Function to analyze sentiment
 def analyze_sentiment(face):
     # Convert face to RGB
@@ -26,58 +38,29 @@ TEXT_SIZE = 3
 
 # Function to detect faces, analyze sentiment, and draw a red box around them
 def detect_and_draw_faces(frame):
-    # Convert frame to RGB
-    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    # Convert the frame to a PIL image
-    pil_image = Image.fromarray(rgb_frame)
-    # Analyze sentiment using the Hugging Face pipeline
-    results = emotion_pipeline(pil_image)
-
-    # Print the results to understand the structure
-    print(results)
+    # Detect faces using MTCNN
+    results = mtcnn.detect_faces(frame)
 
-    # Draw on the PIL image
-    draw = ImageDraw.Draw(pil_image)
-
-    # Iterate through detected faces
+    # Draw on the frame
     for result in results:
-        box = result['box']
-        sentiment = result['label']
-
-        # Draw rectangle and text
-        x, y, w, h = box['left'], box['top'], box['width'], box['height']
-        draw.rectangle(((x, y), (x+w, y+h)), outline="red", width=3)
+        x, y, w, h = result['box']
+        face = frame[y:y+h, x:x+w]
+        sentiment = analyze_sentiment(face)
+        cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), 10)  # Thicker red box
 
         # Calculate position for the text background and the text itself
-        text_size = draw.textsize(sentiment)
-        background_tl = (x, y - text_size[1] - 5)
-        background_br = (x + text_size[0], y)
+        text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
+        text_x = x
+        text_y = y - 10
+        background_tl = (text_x, text_y - text_size[1])
+        background_br = (text_x + text_size[0], text_y + 5)
 
         # Draw black rectangle as background
-        draw.rectangle([background_tl, background_br], fill="black")
+        cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
        # Draw white text on top
-        draw.text((x, y - text_size[1]), sentiment, fill="white")
-
-        # Convert back to OpenCV format
-        frame_with_boxes = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+        cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)
 
-    return frame_with_boxes
-
-# Function to capture video from webcam
-def video_stream():
-    video_capture = cv2.VideoCapture(0)
-    if not video_capture.isOpened():
-        st.error("Error: Could not open video capture device.")
-        return
-
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-        yield frame
-
-    video_capture.release()
+    return frame
 
 # Streamlit UI
 st.markdown(
@@ -130,26 +113,29 @@ with col2:
 
 sentiment_placeholder = st.empty()
 
-# Start video stream
-video_capture = cv2.VideoCapture(0)
-if not video_capture.isOpened():
-    st.error("Error: Could not open video capture device.")
-else:
-    while True:
-        ret, frame = video_capture.read()
-        if not ret:
-            st.error("Error: Failed to read frame from video capture device.")
-            break
-
-        # Detect faces, analyze sentiment, and draw red boxes with sentiment labels
-        frame_with_boxes = detect_and_draw_faces(frame)
-
-        # Display the input stream with the red box around the face
-        video_placeholder.image(frame_with_boxes, channels="BGR")
-
-        # Display the output stream (here it's the same as input, modify as needed)
-        output_placeholder.image(frame_with_boxes, channels="BGR")
-
-        # Add a short delay to control the frame rate
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
+def video_frame_callback(frame):
+    try:
+        with lock:
+            img = frame.to_ndarray(format="bgr24")
+            img_container["webcam"] = img
+            frame_with_boxes = detect_and_draw_faces(img)
+            img_container["analyzed"] = frame_with_boxes
+
+    except Exception as e:
+        st.error(f"Error processing frame: {e}")
+
+    return frame
+
+ctx = webrtc_streamer(key="webcam", video_frame_callback=video_frame_callback)
+
+while ctx.state.playing:
+    with lock:
+        print(img_container)
+        img = img_container["webcam"]
+        frame_with_boxes = img_container["analyzed"]
+
+    if img is None:
+        continue
+
+    video_placeholder.image(img, channels="BGR")
+    output_placeholder.image(frame_with_boxes, channels="BGR")
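The pattern sentiment.py adopts above, a WebRTC callback on a worker thread handing frames to the main Streamlit script through a lock-guarded container, is easy to get subtly wrong, so here is a stripped-down sketch of just the hand-off (illustrative only, not part of the commit; a queue.Queue, as used in the object detection app, is an equally valid choice):

import threading

import av
import streamlit as st
from streamlit_webrtc import webrtc_streamer

lock = threading.Lock()
shared = {"frame": None}


def callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="bgr24")
    with lock:
        shared["frame"] = img  # store only; keep heavy processing outside the lock
    return frame


ctx = webrtc_streamer(key="hand-off-sketch", video_frame_callback=callback)
placeholder = st.empty()

while ctx.state.playing:
    with lock:
        img = shared["frame"]
    if img is not None:
        placeholder.image(img, channels="BGR")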
utils/__init__.py ADDED
@@ -0,0 +1 @@
+
utils/download.py ADDED
@@ -0,0 +1,50 @@
+import urllib.request
+from pathlib import Path
+
+import streamlit as st
+
+
+# This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
+def download_file(url, download_to: Path, expected_size=None):
+    # Don't download the file twice.
+    # (If possible, verify the download using the file length.)
+    if download_to.exists():
+        if expected_size:
+            if download_to.stat().st_size == expected_size:
+                return
+        else:
+            st.info(f"{url} is already downloaded.")
+            if not st.button("Download again?"):
+                return
+
+    download_to.parent.mkdir(parents=True, exist_ok=True)
+
+    # These are handles to two visual elements to animate.
+    weights_warning, progress_bar = None, None
+    try:
+        weights_warning = st.warning("Downloading %s..." % url)
+        progress_bar = st.progress(0)
+        with open(download_to, "wb") as output_file:
+            with urllib.request.urlopen(url) as response:
+                length = int(response.info()["Content-Length"])
+                counter = 0.0
+                MEGABYTES = 2.0**20.0
+                while True:
+                    data = response.read(8192)
+                    if not data:
+                        break
+                    counter += len(data)
+                    output_file.write(data)
+
+                    # We perform animation by overwriting the elements.
+                    weights_warning.warning(
+                        "Downloading %s... (%6.2f/%6.2f MB)"
+                        % (url, counter / MEGABYTES, length / MEGABYTES)
+                    )
+                    progress_bar.progress(min(counter / length, 1.0))
+    # Finally, we remove these visual elements by calling .empty().
+    finally:
+        if weights_warning is not None:
+            weights_warning.empty()
+        if progress_bar is not None:
+            progress_bar.empty()
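A short usage sketch for download_file (illustrative, not part of the commit): the call is effectively idempotent, so it can sit at the top of a Streamlit script, as app.py does for the model weights, and skips the network fetch whenever a file of the expected size is already on disk.

from pathlib import Path

from utils.download import download_file

PROTOTXT_URL = (
    "https://github.com/robmarkcole/object-detection-app/raw/master/model/"
    "MobileNetSSD_deploy.prototxt.txt"
)
target = Path("models") / "MobileNetSSD_deploy.prototxt.txt"

# First run downloads with a progress bar; later runs return immediately
# because the file already exists with the expected size.
download_file(PROTOTXT_URL, target, expected_size=29353)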
utils/turn.py ADDED
@@ -0,0 +1,39 @@
+import logging
+import os
+
+import streamlit as st
+from twilio.base.exceptions import TwilioRestException
+from twilio.rest import Client
+
+logger = logging.getLogger(__name__)
+
+
+def get_ice_servers():
+    """Use Twilio's TURN server because Streamlit Community Cloud has changed
+    its infrastructure and WebRTC connection cannot be established without TURN server now.  # noqa: E501
+    We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
+    but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656  # noqa: E501
+    See https://github.com/whitphx/streamlit-webrtc/issues/1213
+    """
+
+    # Ref: https://www.twilio.com/docs/stun-turn/api
+    try:
+        account_sid = os.environ["TWILIO_ACCOUNT_SID"]
+        auth_token = os.environ["TWILIO_AUTH_TOKEN"]
+    except KeyError:
+        logger.warning(
+            "Twilio credentials are not set. Fallback to a free STUN server from Google."  # noqa: E501
+        )
+        return [{"urls": ["stun:stun.l.google.com:19302"]}]
+
+    client = Client(account_sid, auth_token)
+
+    try:
+        token = client.tokens.create()
+    except TwilioRestException as e:
+        st.warning(
+            f"Error occurred while accessing Twilio API. Fallback to a free STUN server from Google. ({e})"  # noqa: E501
+        )
+        return [{"urls": ["stun:stun.l.google.com:19302"]}]
+
+    return token.ice_servers