Spaces:
Runtime error
Runtime error
isLinXu
commited on
Commit
·
699ee6d
1
Parent(s):
a96c370
update app
Browse files- app.py +92 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
os.system("pip install tensorflow")
|
4 |
+
os.system("pip install modelscope")
|
5 |
+
os.system("pip install thop")
|
6 |
+
os.system("pip install easydict ")
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import PIL.Image as Image
|
10 |
+
import torch
|
11 |
+
from modelscope.pipelines import pipeline
|
12 |
+
from modelscope.utils.constant import Tasks
|
13 |
+
import cv2
|
14 |
+
import numpy as np
|
15 |
+
import random
|
16 |
+
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
warnings.filterwarnings("ignore")
|
20 |
+
|
21 |
+
def object_detection(img_pil, confidence_threshold, device):
|
22 |
+
# 加载模型
|
23 |
+
p = pipeline(task='image-object-detection', model='damo/cv_tinynas_object-detection_damoyolo', device=device)
|
24 |
+
|
25 |
+
# 传入图片进行推理
|
26 |
+
result = p(img_pil)
|
27 |
+
# 读取图片
|
28 |
+
img_cv = cv2.cvtColor(np.asarray(img_pil), cv2.COLOR_RGB2BGR)
|
29 |
+
# 获取bbox和类别
|
30 |
+
scores = result['scores']
|
31 |
+
boxes = result['boxes']
|
32 |
+
labels = result['labels']
|
33 |
+
# 遍历每个bbox
|
34 |
+
for i in range(len(scores)):
|
35 |
+
# 只绘制置信度大于设定阈值的bbox
|
36 |
+
if scores[i] > confidence_threshold:
|
37 |
+
# 随机生成颜色
|
38 |
+
class_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
39 |
+
# 获取bbox坐标
|
40 |
+
x1, y1, x2, y2 = boxes[i]
|
41 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
42 |
+
# 绘制bbox
|
43 |
+
cv2.rectangle(img_cv, (x1, y1), (x2, y2), class_color, thickness=2)
|
44 |
+
# 绘制类别标签
|
45 |
+
label = f"{labels[i]}: {scores[i]:.2f}"
|
46 |
+
cv2.putText(img_cv, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, class_color, thickness=2)
|
47 |
+
img_pil = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
|
48 |
+
return img_pil
|
49 |
+
|
50 |
+
|
51 |
+
def download_test_image():
|
52 |
+
# Images
|
53 |
+
torch.hub.download_url_to_file(
|
54 |
+
'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
|
55 |
+
'bus.jpg')
|
56 |
+
torch.hub.download_url_to_file(
|
57 |
+
'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
|
58 |
+
'dogs.jpg')
|
59 |
+
torch.hub.download_url_to_file(
|
60 |
+
'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
|
61 |
+
'zidane.jpg')
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
download_test_image()
|
66 |
+
# 定义输入和输出
|
67 |
+
input_image = gr.inputs.Image(type='pil')
|
68 |
+
input_slide = gr.inputs.Slider(minimum=0, maximum=1, step=0.05, default=0.5, label="Confidence Threshold")
|
69 |
+
input_device = gr.inputs.Radio(["cpu", "cuda", "gpu"], default="cpu")
|
70 |
+
output_image = gr.outputs.Image(type='pil')
|
71 |
+
|
72 |
+
examples = [['bus.jpg', 0.45, "cpu"],
|
73 |
+
['dogs.jpg', 0.45, "cpu"],
|
74 |
+
['zidane.jpg', 0.45, "cpu"]]
|
75 |
+
title = "DAMO-YOLO web demo"
|
76 |
+
description = "<div align='center'><img src='https://raw.githubusercontent.com/tinyvision/DAMO-YOLO/master/assets/logo.png' width='800''/><div>" \
|
77 |
+
"<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a> DAMO-YOLO DAMO-YOLO DAMO-YOLO:一种快速准确的目标检测方法,采用了一些新技术,包括 NAS 主干、高效的 RepGFPN、ZeroHead、AlignedOTA 和蒸馏增强。" \
|
78 |
+
"DAMO-YOLO: a fast and accurate object detection method with some new techs, including NAS backbones, efficient RepGFPN, ZeroHead, AlignedOTA, and distillation enhancement..</p>"
|
79 |
+
article = "<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a></p>" \
|
80 |
+
"<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
|
81 |
+
|
82 |
+
# 创建 Gradio 接口并运行
|
83 |
+
gr.Interface(
|
84 |
+
fn=object_detection,
|
85 |
+
inputs=[
|
86 |
+
input_image, input_slide, input_device
|
87 |
+
],
|
88 |
+
outputs=output_image,
|
89 |
+
title=title,
|
90 |
+
examples=examples,
|
91 |
+
description=description, article=article
|
92 |
+
).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
2 |
+
ultralytics~=8.0.169
|
3 |
+
wget~=3.2
|
4 |
+
opencv-python~=4.6.0.66
|
5 |
+
numpy~=1.23.0
|
6 |
+
pillow~=9.4.0
|
7 |
+
gradio~=3.42.0
|
8 |
+
pyyaml~=6.0
|
9 |
+
wandb~=0.13.11
|
10 |
+
tqdm~=4.65.0
|
11 |
+
matplotlib~=3.7.1
|
12 |
+
pandas~=2.0.0
|
13 |
+
seaborn~=0.12.2
|
14 |
+
requests~=2.31.0
|
15 |
+
psutil~=5.9.4
|
16 |
+
thop~=0.1.1-2209072238
|
17 |
+
timm~=0.9.2
|
18 |
+
super-gradients~=3.2.0
|
19 |
+
openmim
|