Spaces:
Runtime error
Runtime error
Niv Sardi
commited on
Commit
•
e04a33d
1
Parent(s):
26ef429
entities: make read_entities return entities
Browse filesSigned-off-by: Niv Sardi <xaiki@evilgiggle.com>
- python/entity.py +1 -1
- python/write_data.py +1 -1
- train.sh +60 -0
python/entity.py
CHANGED
@@ -7,7 +7,7 @@ from common import defaults
|
|
7 |
def read_entities(fn = defaults.MAIN_CSV_PATH):
|
8 |
with open(fn, newline='') as csvfile:
|
9 |
reader = csv.DictReader(csvfile)
|
10 |
-
bcos = { d['bco']:d for
|
11 |
return bcos
|
12 |
|
13 |
class Entity(NamedTuple):
|
|
|
7 |
def read_entities(fn = defaults.MAIN_CSV_PATH):
|
8 |
with open(fn, newline='') as csvfile:
|
9 |
reader = csv.DictReader(csvfile)
|
10 |
+
bcos = { d['bco']:Entity.from_dict(d) for d in reader}
|
11 |
return bcos
|
12 |
|
13 |
class Entity(NamedTuple):
|
python/write_data.py
CHANGED
@@ -5,7 +5,7 @@ import argparse
|
|
5 |
from common import defaults
|
6 |
|
7 |
def gen_data_yaml(bcos, datapath='../data'):
|
8 |
-
names = [f"{d
|
9 |
return f'''
|
10 |
# this file is autogenerated by write_data.py
|
11 |
|
|
|
5 |
from common import defaults
|
6 |
|
7 |
def gen_data_yaml(bcos, datapath='../data'):
|
8 |
+
names = [f"{d.name}" for d in bcos.values()]
|
9 |
return f'''
|
10 |
# this file is autogenerated by write_data.py
|
11 |
|
train.sh
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/sh
|
2 |
+
|
3 |
+
set -e
|
4 |
+
|
5 |
+
PY=python3
|
6 |
+
DATA_PATH=${PWD}/data
|
7 |
+
|
8 |
+
${PY} ./python/write_data.py ./data/entities.csv --data ${DATA_PATH} > ${DATA_PATH}/data.yaml
|
9 |
+
grep nc ${DATA_PATH}/data.yaml > ${DATA_PATH}/custom_yolov5s.yaml
|
10 |
+
cat <<EOF >> ${DATA_PATH}/custom_yolov5s.yaml
|
11 |
+
# parameters
|
12 |
+
nc: {num_classes} # number of classes # CHANGED HERE
|
13 |
+
depth_multiple: 0.33 # model depth multiple
|
14 |
+
width_multiple: 0.50 # layer channel multiple
|
15 |
+
|
16 |
+
# anchors
|
17 |
+
anchors:
|
18 |
+
- [10,13, 16,30, 33,23] # P3/8
|
19 |
+
- [30,61, 62,45, 59,119] # P4/16
|
20 |
+
- [116,90, 156,198, 373,326] # P5/32
|
21 |
+
|
22 |
+
# YOLOv5 backbone
|
23 |
+
backbone:
|
24 |
+
# [from, number, module, args]
|
25 |
+
[[-1, 1, Focus, [64, 3]], # 0-P1/2
|
26 |
+
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
27 |
+
[-1, 3, BottleneckCSP, [128]],
|
28 |
+
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
29 |
+
[-1, 9, BottleneckCSP, [256]],
|
30 |
+
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
31 |
+
[-1, 9, BottleneckCSP, [512]],
|
32 |
+
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
33 |
+
[-1, 1, SPP, [1024, [5, 9, 13]]],
|
34 |
+
[-1, 3, BottleneckCSP, [1024, False]], # 9
|
35 |
+
]
|
36 |
+
|
37 |
+
# YOLOv5 head
|
38 |
+
head:
|
39 |
+
[[-1, 1, Conv, [512, 1, 1]],
|
40 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
41 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
42 |
+
[-1, 3, BottleneckCSP, [512, False]], # 13
|
43 |
+
|
44 |
+
[-1, 1, Conv, [256, 1, 1]],
|
45 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
46 |
+
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
47 |
+
[-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small)
|
48 |
+
|
49 |
+
[-1, 1, Conv, [256, 3, 2]],
|
50 |
+
[[-1, 14], 1, Concat, [1]], # cat head P4
|
51 |
+
[-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium)
|
52 |
+
|
53 |
+
[-1, 1, Conv, [512, 3, 2]],
|
54 |
+
[[-1, 10], 1, Concat, [1]], # cat head P5
|
55 |
+
[-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large)
|
56 |
+
|
57 |
+
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
|
58 |
+
]
|
59 |
+
EOF
|
60 |
+
(cd yolov5; ${PY} train.py --img 416 --batch 80 --epochs 1000 --data ${DATA_PATH}/data.yaml --cfg ${DATA_PATH}/custom_yolov5s.yaml --weights '')
|