|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image,ImageDraw, ImageFont |
|
from transformers import pipeline |
|
import torch |
|
from random import choice |
|
import os |
|
from datetime import datetime |
|
|
|
|
|
# Build the DETR object-detection pipeline, on the first CUDA GPU when one is
# available and on the CPU otherwise (-1 is the pipeline's CPU ordinal).
# The device is given to ``pipeline`` itself instead of moving
# ``detector.model`` afterwards, so the pipeline places its input tensors on
# the same device as the model weights.
detector = pipeline(
    model="facebook/detr-resnet-101",
    use_fast=True,
    device=0 if torch.cuda.is_available() else -1,
)
|
|
|
# Palette from which a colour is picked at random for each newly seen label.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Minimum detection confidence, in percent, for a box to be drawn.
# Rebound at runtime by ``change_threshold``.
threshold = 90

# Labels that are annotated; detections with any other label are skipped.
# Rebound at runtime by ``update_labels``.
label_list = ["person", "car", "truck"]

# Maps a label to the colour chosen for it, filled in lazily while annotating.
label_color_dict = dict()
|
|
|
def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on one image and return its output.

    The raw detector output is also printed to stdout.

    Args:
        in_pil_img: The image to run detection on.

    Returns:
        Whatever ``detector`` returns for the image, unmodified.
    """
    detections = detector(in_pil_img)
    print(f"检测结果:{detections}")
    return detections
|
|
|
|
|
def get_font_size(box_width, min_size=10, max_size=48):
    """Compute a label font size from the width of a bounding box.

    The size starts at one tenth of ``box_width`` (truncated to an int) with
    a fixed floor of 24, is then capped at ``max_size``, and finally raised
    to ``min_size`` if it is still below it. Because of the floor of 24,
    ``min_size`` only has an effect when it is larger than the capped value.

    Args:
        box_width: Width of the bounding box in pixels.
        min_size: Lower bound applied last.
        max_size: Upper bound applied before ``min_size``.

    Returns:
        The font size as an int.
    """
    size = int(box_width / 10)
    if size < 24:
        size = 24
    if size > max_size:
        size = max_size
    if size < min_size:
        size = min_size
    return size
|
|
|
def get_text_position(box, text_bbox):
    """Return the top-left position at which to draw a box's label.

    The label is placed directly above the box when its height fits between
    the top of the image (y = 0) and the top edge of the box; otherwise it is
    placed at the box's own top-left corner.

    Args:
        box: Mapping with at least the keys ``'xmin'`` and ``'ymin'``.
        text_bbox: ``(left, top, right, bottom)`` bounding box of the label
            text; only ``top`` and ``bottom`` are used.

    Returns:
        An ``(x, y)`` tuple.
    """
    xmin, ymin = box['xmin'], box['ymin']
    text_height = text_bbox[3] - text_bbox[1]

    if ymin - text_height >= 0:
        return (xmin, ymin - text_height)
    return (xmin, ymin)
|
|
|
# Fonts that have already been loaded, keyed by pixel size, so that each size
# is read from disk once instead of once per box per frame.
_FONT_CACHE = {}


def _load_font(size):
    """Return a label font for ``size`` pixels, loading each size only once.

    ``arial.ttf`` is tried first, then ``DejaVuSans.ttf``; if neither can be
    opened, Pillow's built-in default font is used so that annotation still
    works on machines where those font files are not installed.
    """
    font = _FONT_CACHE.get(size)
    if font is not None:
        return font

    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            font = ImageFont.truetype(font=font_name, size=size)
            break
        except OSError:
            # Font file not found or unreadable; try the next candidate.
            continue
    else:
        # NOTE(review): called without a size for compatibility with older
        # Pillow releases, so this last-resort font may not honour ``size``.
        font = ImageFont.load_default()

    _FONT_CACHE[size] = font
    return font


def get_annotated_image(in_pil_img):
    """Detect objects in an image and draw the accepted detections onto it.

    A detection is drawn only when its confidence (converted to percent and
    rounded to one decimal) is at least the module-level ``threshold`` and
    its label is in the module-level ``label_list``. Each label gets a colour
    picked at random from ``COLORS`` the first time it is seen; the choice is
    remembered in ``label_color_dict``.

    Args:
        in_pil_img: PIL image to analyse. It is drawn on in place.

    Returns:
        The annotated image converted to RGB, as a ``numpy`` array.
    """
    draw = ImageDraw.Draw(in_pil_img)
    in_results = query_data(in_pil_img)

    for prediction in in_results:
        box = prediction['box']
        label = prediction['label']
        score = round(prediction['score'] * 100, 1)

        # Skip detections that are too uncertain or not of a selected class.
        if score < threshold or label not in label_list:
            continue

        # Reuse the colour already assigned to this label, if any.
        color = label_color_dict.get(label)
        if color is None:
            color = choice(COLORS)
            label_color_dict[label] = color

        # Scale the label font with the width of the box.
        box_width = box['xmax'] - box['xmin']
        font = _load_font(get_font_size(box_width))

        # Measure the label at the origin; only its size is needed here.
        text = f"{label}: {score}%"
        text_bbox = draw.textbbox((0, 0), text, font=font)

        draw.rectangle(
            [box['xmin'], box['ymin'], box['xmax'], box['ymax']],
            outline=color,
            width=3,
        )

        text_pos = get_text_position(box, text_bbox)
        draw.text(text_pos, text, fill=color, font=font)

    return np.array(in_pil_img.convert('RGB'))
|
|
|
|
|
def process_video(input_video_path):
    """Annotate every frame of a video and write the result to a new file.

    Each frame is converted from BGR to RGB, passed through
    ``get_annotated_image``, converted back to BGR and written with the
    ``mp4v`` codec to ``./output_videos/output_<timestamp>.mp4`` using the
    input's frame size and frame rate.

    Args:
        input_video_path: Path of the video to process.

    Returns:
        The path of the written output video.

    Raises:
        ValueError: If no path is given, the input video cannot be opened,
            or the output video writer cannot be opened.
    """
    # Nothing was supplied (e.g. the button was pressed without a video).
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")

    cap = cv2.VideoCapture(input_video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        output_dir = './output_videos'
        os.makedirs(output_dir, exist_ok=True)

        # Name the output after the current time (one-second resolution).
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_video_filename = f"output_{timestamp}.mp4"
        output_video_path = os.path.join(output_dir, output_video_filename)
        print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")

        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        # Fail loudly instead of silently producing no usable output file.
        if not out.isOpened():
            raise ValueError("无法创建输出视频文件")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV delivers BGR; the annotation step works on an RGB PIL image.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_frame)
            print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}")

            annotated_frame = get_annotated_image(pil_image)
            bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
            print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")

            # The writer was opened for (width, height); make the frame match.
            if bgr_frame.shape[:2] != (height, width):
                bgr_frame = cv2.resize(bgr_frame, (width, height))

            print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")
            out.write(bgr_frame)
    finally:
        # Release both handles even when a frame fails part-way through.
        cap.release()
        if out is not None:
            out.release()

    return output_video_path
|
|
|
def change_threshold(value):
    """Set the module-level confidence threshold and describe the new value.

    Args:
        value: New threshold, in percent.

    Returns:
        A message string containing the threshold that is now in effect.
    """
    global threshold
    threshold = value
    message = f"当前置信度阈值为{threshold}%"
    return message
|
|
|
def update_labels(selected_labels):
    """Replace the module-level list of labels to annotate.

    Args:
        selected_labels: The labels that should be annotated from now on.

    Returns:
        The same ``selected_labels`` object that was passed in.
    """
    global label_list
    label_list = selected_labels
    return label_list
|
|
|
# Build the Gradio UI: a threshold slider, a label selector, and a
# video-in / button / video-out row wired to ``process_video``.
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
    # Page header. The style attribute is delimited with double quotes because
    # it contains a single-quoted font name; delimiting it with single quotes
    # would end the attribute at that name and drop the remaining
    # declarations. ``serif`` is left unquoted so it is the generic family.
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">"
        "基于AI的安全风险识别及防控应用</div>"
    )

    # Confidence threshold control plus a textbox echoing the current value.
    threshold_slider = gr.Slider(minimum=0, maximum=100, value=threshold, step=1, label="置信度阈值")
    textbox = gr.Textbox(value=f"当前置信度阈值为{threshold}%", label="置信度显示")

    threshold_slider.change(fn=change_threshold, inputs=[threshold_slider], outputs=[textbox])

    # Which labels to annotate; all initial labels start selected.
    label_checkboxes = gr.CheckboxGroup(choices=label_list, value=label_list, label="检测目标")

    label_checkboxes.change(fn=update_labels, inputs=[label_checkboxes], outputs=[label_checkboxes])

    with gr.Row():
        input_video = gr.Video(label="输入视频")
        detect_button = gr.Button("开始检测", variant="primary")
        output_video = gr.Video(label="输出视频")

    # Run detection on the uploaded video and show the annotated result.
    detect_button.click(process_video, inputs=input_video, outputs=output_video)

demo.launch()