yolov5_tracking / app.py
xfys's picture
Update app.py
4ce3394
raw
history blame
3.47 kB
import gradio as gr
import tempfile
import os
import track
import shutil
from pathlib import Path
from yolov5 import detect
from PIL import Image
# Runtime dependency bootstrap. Cython is installed first, presumably because
# cython-bbox needs it at build time. Invoking pip as ``sys.executable -m pip``
# installs into the interpreter that is actually running this app, rather than
# whichever ``pip`` happens to be first on PATH, and avoids going through a
# shell. Exit codes are ignored (check=False), matching the best-effort
# behaviour of the former os.system() calls.
# NOTE(review): this runs after ``import track`` above; if that import already
# needs cython_bbox, the very first launch may still fail — confirm.
import subprocess
import sys

for _package in ("Cython", "cython-bbox"):
    subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)
# Object detection
def Detect(image):
    """Run YOLOv5 detection on one image and return the annotated result.

    The input array is written as ``temp.jpg`` into a private temporary
    directory under the current working directory. ``detect.run`` is pointed
    at it with its output project set to that directory. The first file found
    in ``<tmp>/tempresult`` is then loaded back into memory and the temporary
    directory is removed.

    Args:
        image: Image array supplied by the ``gr.Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The result ``PIL.Image.Image``, or ``None`` if no image was given.
    """
    if image is None:
        # Nothing uploaded: leave the output empty instead of failing in PIL.
        return None
    # The context manager removes the directory even if detection raises,
    # which the former explicit cleanup() call did not guarantee.
    with tempfile.TemporaryDirectory(dir="./") as temp_dir:
        temp_image_path = os.path.join(temp_dir, "temp.jpg")
        # JPEG cannot store an alpha channel; force RGB so RGBA arrays still save.
        Image.fromarray(image).convert("RGB").save(temp_image_path)
        # Pass temp_dir unchanged: since Python 3.12 it is an absolute path,
        # and prefixing it with "./" would turn it into a path under the CWD.
        detect.run(
            source=temp_image_path,
            data="test_image/FLIR.yaml",
            weights="weights/best.pt",
            project=temp_dir,
            name="tempresult",
            hide_conf=False,
            conf_thres=0.35,
        )
        result_dir = os.path.join(temp_dir, "tempresult")
        result_path = os.path.join(result_dir, os.listdir(result_dir)[0])
        # Copy the pixels into memory (and close the file) so the returned
        # image outlives the temporary directory.
        with Image.open(result_path) as result_image:
            return result_image.copy()
# Example images shown under the detection tab. The paths are relative to the
# working directory the app is started from.
example_image = [
    "./test_image/" + frame_file
    for frame_file in (
        "video-2SReBn5LtAkL5HMj2-frame-005072-MA7NCLQGoqq9aHaiL.jpg",
        "video-2rsjnZFyGQGeynfbv-frame-003708-6fPQbB7jtibwaYAE7.jpg",
        "video-2SReBn5LtAkL5HMj2-frame-000317-HTgPBFgZyPdwQnNvE.jpg",
        "video-jNQtRj6NGycZDEXpe-frame-002515-J3YntG8ntvZheKK3P.jpg",
        "video-kDDWXrnLSoSdHCZ7S-frame-003063-eaKjPvPskDPjenZ8S.jpg",
        "video-r68Yr9RPWEp5fW2ZF-frame-000333-X6K5iopqbmjKEsSqN.jpg",
    )
]
# Object tracking
def Track(video, tracking_method):
    """Track objects in a video and return the path of the result video.

    Output is written to ``./temp/tempresult``. ``./temp`` is wiped at the
    start of every call, so only the most recent result is kept on disk.

    Args:
        video: Path of the uploaded video supplied by ``gr.Video``, or
            ``None`` when nothing was uploaded.
        tracking_method: Tracker name forwarded to ``track.run`` (the UI
            offers ``"bytetrack"`` and ``"strongsort"``).

    Returns:
        Path to the result video, or ``None`` if no video was given.
    """
    if video is None:
        # Checked before touching ./temp so an empty submit does not wipe
        # the previous result.
        return None
    temp_dir = "./temp"
    # Start from an empty folder. The folder may not exist yet (e.g. a fresh
    # checkout), in which case rmtree raises FileNotFoundError.
    try:
        shutil.rmtree(temp_dir)
    except FileNotFoundError:
        pass
    os.makedirs(temp_dir, exist_ok=True)
    video_name = os.path.basename(video)
    track.run(
        source=video,
        yolo_weights=Path("weights/best2.pt"),
        reid_weights=Path("weights/osnet_x0_25_msmt17.pt"),
        project=Path(temp_dir),
        name="tempresult",
        tracking_method=tracking_method,
    )
    # NOTE(review): assumes track.run saves the result under the input's file
    # name; the UI only advertises .mp4 input — confirm for other containers.
    return os.path.join(temp_dir, "tempresult", video_name)
# Example rows for the tracking tab. Each row is [video path, tracker choice];
# the tracker column is left as None for every example.
example_video = [
    [f"./video/{clip}", None]
    for clip in (
        "5.mp4",
        "bicyclecity.mp4",
        "9.mp4",
        "8.mp4",
        "4.mp4",
        "car.mp4",
    )
]
# Detection tab: one image in, the result image out.
# Labels: "Upload an infrared image, jpg only" / "Detection result".
iface_Image = gr.Interface(
    fn=Detect,
    inputs=gr.Image(label="上传一张红外图像,仅支持jpg格式"),
    outputs=gr.Image(label="检测结果"),
    examples=example_image,
)
# Tracking tab: a video plus a tracker choice in, the result video out.
# Labels: "Upload an infrared video, mp4 only" / "Choose a tracker" /
# "Tracking result".
iface_video = gr.Interface(
    fn=Track,
    inputs=[
        gr.Video(label="上传段红外视频,仅支持mp4格式"),
        gr.Radio(
            ["bytetrack", "strongsort"],
            label="track method",
            info="选择追踪器",
            value="bytetrack",
        ),
    ],
    outputs=gr.Video(label="追踪结果"),
    examples=example_video,
)
# Tabs: "Object tracking" / "Object detection";
# title: "Infrared object detection and tracking".
demo = gr.TabbedInterface(
    [iface_video, iface_Image],
    tab_names=["目标追踪", "目标检测"],
    title="红外目标检测追踪",
)

if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so merely
    # importing this module does not launch a publicly shared app.
    demo.launch(share=True)