xfys commited on
Commit
dd085d6
1 Parent(s): 5ac7375

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import os
4
+ import track
5
+ import shutil
6
+ from pathlib import Path
7
+ from yolov5 import detect
8
+ from PIL import Image
9
+
10
+ # 目标检测
11
+ def Detect(image):
12
+ # 创建临时文件夹
13
+ temp_path = tempfile.TemporaryDirectory(dir="./")
14
+ temp_dir = temp_path.name
15
+ # 临时图片的路径
16
+ temp_image_path = os.path.join(temp_dir, f"temp.jpg")
17
+ # 存储临时图片
18
+ img = Image.fromarray(image)
19
+ img.save(temp_image_path)
20
+ # 结果图片的存储目录
21
+ temp_result_path = os.path.join(temp_dir, "tempresult")
22
+ # 对临时图片进行检测
23
+ detect.run(source=temp_image_path, data="test_image/FLIR.yaml", weights="weights/best.pt", project=f'./{temp_dir}',name = 'tempresult', hide_conf=False, conf_thres=0.35)
24
+ # 结果图片的路径
25
+ temp_result_path = os.path.join(temp_result_path, os.listdir(temp_result_path)[0])
26
+ # 读取结果图片
27
+ result_image = Image.open(temp_result_path).copy()
28
+ # 删除临时文件夹
29
+ temp_path.cleanup()
30
+ return result_image
31
+
32
+ # 候选图片
33
+ example_image= [
34
+ "./test_image/video-2SReBn5LtAkL5HMj2-frame-005072-MA7NCLQGoqq9aHaiL.jpg",
35
+ "./test_image/video-2rsjnZFyGQGeynfbv-frame-003708-6fPQbB7jtibwaYAE7.jpg",
36
+ "./test_image/video-2SReBn5LtAkL5HMj2-frame-000317-HTgPBFgZyPdwQnNvE.jpg",
37
+ "./test_image/video-jNQtRj6NGycZDEXpe-frame-002515-J3YntG8ntvZheKK3P.jpg",
38
+ "./test_image/video-kDDWXrnLSoSdHCZ7S-frame-003063-eaKjPvPskDPjenZ8S.jpg",
39
+ "./test_image/video-r68Yr9RPWEp5fW2ZF-frame-000333-X6K5iopqbmjKEsSqN.jpg"
40
+ ]
41
+
42
+ # 目标追踪
43
+ def Track(video, tracking_method):
44
+ # 存储临时视频的文件夹
45
+ temp_dir = "./temp"
46
+ # 先清空temp文件夹
47
+ shutil.rmtree("./temp")
48
+ os.mkdir("./temp")
49
+ # 获取视频的名字
50
+ video_name = os.path.basename(video)
51
+ # 对视频进行检测
52
+ track.run(source=video, yolo_weights=Path("weights/best2.pt"),reid_weights=Path("weights/osnet_x0_25_msmt17.pt") , project=Path(f'./{temp_dir}'),name = 'tempresult', tracking_method=tracking_method)
53
+ # 结果视频的路径
54
+ temp_result_path = os.path.join(f'./{temp_dir}', "tempresult", video_name)
55
+ # 返回结果视频的路径
56
+ return temp_result_path
57
+
58
+ # 候选视频
59
+ example_video= [
60
+ ["./video/5.mp4", None],
61
+ ["./video/bicyclecity.mp4", None],
62
+ ["./video/9.mp4", None],
63
+ ["./video/8.mp4", None],
64
+ ["./video/4.mp4", None],
65
+ ["./video/car.mp4", None],
66
+ ]
67
+
68
+ iface_Image = gr.Interface(fn=Detect,
69
+ inputs=gr.Image(label="上传一张红外图像,仅支持jpg格式"),
70
+ outputs=gr.Image(label="检测结果"),
71
+ examples=example_image)
72
+
73
+ iface_video = gr.Interface(fn=Track,
74
+ inputs=[gr.Video(label="上传段红外视频,仅支持mp4格式"),
75
+ gr.Radio(["bytetrack", "strongsort"],
76
+ label="track methond",
77
+ info="选择追踪器",
78
+ value="bytetrack")],
79
+ outputs=gr.Video(label="追踪结果"),
80
+ examples=example_video)
81
+
82
+ demo = gr.TabbedInterface([iface_video, iface_Image], tab_names=["目标追踪", "目标检测"], title="红外目标检测追踪")
83
+
84
+ demo.launch(share=True)
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+