File size: 6,542 Bytes
f93a294 fd08bcf 3123f9f 463eb87 3123f9f e1fe61e 3123f9f 463eb87 3123f9f 463eb87 46dad61 fbb1f4d 85114f8 3123f9f 46dad61 f93a294 fd08bcf 3123f9f 85114f8 3123f9f f93a294 463eb87 3123f9f fbb1f4d fd08bcf 46dad61 85114f8 fd08bcf 85114f8 463eb87 fd08bcf 00c0721 fd08bcf fbb1f4d fd08bcf 85114f8 fd08bcf 85114f8 fd08bcf 463eb87 85114f8 f93a294 fd08bcf 3123f9f f93a294 3123f9f f93a294 3123f9f 46dad61 3123f9f f93a294 3123f9f 46dad61 3123f9f 46dad61 3123f9f f93a294 46dad61 3123f9f f93a294 3123f9f f93a294 3123f9f f93a294 46dad61 3123f9f eec4598 f93a294 46dad61 463eb87 f93a294 3123f9f 463eb87 9e3cf2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import gradio as gr
import cv2
import numpy as np
from PIL import Image,ImageDraw, ImageFont
from transformers import pipeline
import torch
from random import choice
import os
from datetime import datetime
# Initialise the object detector, on the first GPU when one is available.
# The device is given to `pipeline` itself: moving only `detector.model` to
# CUDA after construction leaves the pipeline's own device set to CPU, so the
# pre-processed inputs are not placed on the same device as the weights.
detector = pipeline(
    model="facebook/detr-resnet-101",
    use_fast=True,
    device=0 if torch.cuda.is_available() else -1,
)
# Palette from which each label's colour is picked.
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
          "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
          "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
threshold = 90  # confidence threshold, in percent (compared against score * 100)
label_list = ["person", "car", "truck"]  # labels that are allowed to be drawn
label_color_dict = {}  # label -> colour, filled lazily so colours stay stable
def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on one PIL image.

    The raw prediction list returned by the detector is printed for debugging
    and then handed back to the caller unchanged.
    """
    predictions = detector(in_pil_img)
    print(f"检测结果:{predictions}")
    return predictions
def get_font_size(box_width, min_size=10, max_size=48, *, divisor=10, floor=24):
    """Return a font size proportional to a bounding box's width.

    The raw size is ``int(box_width / divisor)``. It is first raised to at
    least ``floor``, then clamped to ``[min_size, max_size]``; if the two
    bounds conflict, ``min_size`` wins.

    Args:
        box_width: Width of the bounding box in pixels.
        min_size: Hard lower bound on the result.
        max_size: Hard upper bound on the result.
        divisor: Pixels of box width per unit of font size.
        floor: Preferred minimum applied before clamping. With the defaults
            the result lies in 24..48, so ``min_size`` only takes effect when
            ``floor`` is lowered below it.

    Returns:
        The font size as an ``int``.
    """
    font_size = max(floor, int(box_width / divisor))
    return max(min(font_size, max_size), min_size)
def get_text_position(box, text_bbox):
    """Return the (x, y) anchor at which to draw a box's label text.

    The label is placed just above the box when there is room, otherwise just
    inside its top-left corner.

    Args:
        box: Mapping with at least ``'xmin'`` and ``'ymin'`` pixel coordinates.
        text_bbox: ``(left, top, right, bottom)`` of the text as measured at
            origin ``(0, 0)`` (e.g. by ``ImageDraw.textbbox``). ``top`` may be
            positive because the glyphs start below the drawing origin.

    Returns:
        A ``(x, y)`` tuple suitable for ``ImageDraw.text``.
    """
    xmin, ymin = box['xmin'], box['ymin']
    text_top, text_bottom = text_bbox[1], text_bbox[3]
    text_height = text_bottom - text_top
    # Room above the box? Then draw so the text's bottom edge sits on ymin.
    # Offsetting by `text_bottom` (not just the height) accounts for the gap
    # between the drawing origin and the glyphs' top, which otherwise pushes
    # the label down over the box outline.
    if ymin - text_height >= 0:
        return (xmin, ymin - text_bottom)  # above
    return (xmin, ymin)  # inside
# Fonts loaded so far, keyed by size. Re-reading the font file for every box
# of every frame is wasteful, and the number of distinct sizes is small.
_FONT_CACHE = {}


def _load_font(size):
    """Return a font of *size*, loading it at most once per size.

    Tries ``arial.ttf`` first. If it cannot be opened (``ImageFont.truetype``
    raises ``OSError``, e.g. on hosts without that file), falls back to
    Pillow's built-in default font so annotation still works.
    """
    font = _FONT_CACHE.get(size)
    if font is None:
        try:
            font = ImageFont.truetype(font="arial.ttf", size=size)
        except OSError:
            try:
                font = ImageFont.load_default(size=size)  # `size` needs Pillow >= 10.1
            except TypeError:
                font = ImageFont.load_default()
        _FONT_CACHE[size] = font
    return font


def get_annotated_image(in_pil_img):
    """Detect objects in *in_pil_img* and draw labelled boxes onto it in place.

    A prediction is drawn only if its score, as a percentage rounded to one
    decimal, is at least the global ``threshold`` and its label is in the
    global ``label_list``. Each label keeps the colour it was first given
    (remembered in ``label_color_dict``).

    Args:
        in_pil_img: PIL image to annotate; it is modified in place.

    Returns:
        The annotated image converted to RGB, as a NumPy array.
    """
    draw = ImageDraw.Draw(in_pil_img)
    for prediction in query_data(in_pil_img):
        box = prediction['box']
        label = prediction['label']
        score = round(prediction['score'] * 100, 1)
        if score < threshold:
            continue  # drop low-confidence predictions
        if label not in label_list:
            continue  # drop labels that are not selected for display
        # Pick a random colour the first time a label is seen, then reuse it.
        color = label_color_dict.get(label)
        if color is None:
            color = choice(COLORS)
            label_color_dict[label] = color
        font = _load_font(get_font_size(box['xmax'] - box['xmin']))
        text = f"{label}: {score}%"
        text_bbox = draw.textbbox((0, 0), text, font=font)
        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline=color, width=3)
        draw.text(get_text_position(box, text_bbox), text, fill=color, font=font)
    # `in_pil_img` itself has been drawn on; return it as an RGB array.
    return np.array(in_pil_img.convert('RGB'))
def process_video(input_video_path):
    """Annotate every frame of a video with detections and save the result.

    Args:
        input_video_path: Path of the video supplied by the Gradio ``Video``
            input; falsy when nothing was uploaded.

    Returns:
        Path of the annotated ``.mp4`` file written under ``./output_videos``.

    Raises:
        ValueError: If no video was given, the input cannot be opened, or the
            output video writer cannot be created.
    """
    if not input_video_path:
        raise ValueError("未提供输入视频文件")
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError("无法打开输入视频文件")
    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            # Some containers report 0 (or NaN) FPS; VideoWriter needs a positive rate.
            fps = 25.0
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' codec
        output_dir = './output_videos'
        os.makedirs(output_dir, exist_ok=True)
        # Microseconds keep names unique when two requests start within the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_video_path = os.path.join(output_dir, f"output_{timestamp}.mp4")
        print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise ValueError("无法创建输出视频文件")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}")  # debug
            annotated_frame = get_annotated_image(Image.fromarray(rgb_frame))
            bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
            print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")  # debug
            # VideoWriter silently drops frames whose size differs from the declared one.
            if bgr_frame.shape[:2] != (height, width):
                bgr_frame = cv2.resize(bgr_frame, (width, height))
            print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")  # debug
            out.write(bgr_frame)
    finally:
        # Release both handles even if detection or writing raised part-way.
        cap.release()
        if out is not None:
            out.release()
    # Gradio displays the file at this path in the output Video component.
    return output_video_path
def change_threshold(value):
    """Store *value* as the global confidence threshold (in percent).

    Returns:
        A status message for the UI textbox reporting the new threshold.
    """
    global threshold
    threshold = value
    return "当前置信度阈值为{}%".format(threshold)
def update_labels(selected_labels):
    """Replace the global ``label_list`` with the user's current selection.

    Returns the very same selection object, so it can be echoed back to the
    checkbox group.
    """
    global label_list
    label_list = selected_labels
    return label_list
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
    # Page heading. The style attribute is double-quoted so the single-quoted
    # font name inside it does not terminate the attribute early; `serif` is a
    # CSS generic family and must not be quoted.
    gr.HTML("<div style=\"font-family:'Times New Roman', serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;\">基于AI的安全风险识别及防控应用</div>")
    # Confidence threshold: the slider updates the global threshold and the status textbox.
    threshold_slider = gr.Slider(minimum=0, maximum=100, value=threshold, step=1, label="置信度阈值")
    textbox = gr.Textbox(value=f"当前置信度阈值为{threshold}%", label="置信度显示")
    threshold_slider.change(fn=change_threshold, inputs=[threshold_slider], outputs=[textbox])
    # Labels that may be drawn; changing the selection updates the global label_list.
    label_checkboxes = gr.CheckboxGroup(choices=label_list, value=label_list, label="检测目标")
    label_checkboxes.change(fn=update_labels, inputs=[label_checkboxes], outputs=[label_checkboxes])
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        detect_button = gr.Button("开始检测", variant="primary")
        output_video = gr.Video(label="输出视频")
    # Clicking the button runs process_video and shows the returned file path.
    detect_button.click(process_video, inputs=input_video, outputs=output_video)

# NOTE(review): threshold and label_list are module globals, so these settings
# are shared by every concurrent user of the app.
if __name__ == "__main__":
    demo.launch()
|