Paul Bird commited on
Commit
62ee322
1 Parent(s): 61644a5

Upload RunYOLO8n.cs

Browse files
Files changed (1) hide show
  1. RunYOLO8n.cs +12 -4
RunYOLO8n.cs CHANGED
@@ -108,23 +108,28 @@ public class RunYOLO8n : MonoBehaviour
108
  model.AddConstant(new Lays.Constant("0", new int[] { 0 }));
109
  model.AddConstant(new Lays.Constant("1", new int[] { 1 }));
110
  model.AddConstant(new Lays.Constant("4", new int[] { 4 }));
 
 
111
  model.AddConstant(new Lays.Constant("classes_plus_4", new int[] { numClasses + 4 }));
112
  model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { maxOutputBoxes }));
113
  model.AddConstant(new Lays.Constant("iouThreshold", new float[] { iouThreshold }));
114
  model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { scoreThreshold }));
115
 
116
  //Add layers
117
- model.AddLayer(new Lays.Slice("boxCoords0", "output0", "0", "4", "1"));
118
  model.AddLayer(new Lays.Transpose("boxCoords", "boxCoords0", new int[] { 0, 2, 1 }));
119
- model.AddLayer(new Lays.Slice("scores0", "output0", "4", "classes_plus_4", "1"));
 
 
120
 
121
- model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores0",
122
  "maxOutputBoxes", "iouThreshold", "scoreThreshold",
123
  centerPointBox: Lays.CenterPointBox.Center
124
  ));
125
 
126
  model.outputs.Clear();
127
  model.AddOutput("boxCoords");
 
128
  model.AddOutput("NMS");
129
  }
130
 
@@ -165,12 +170,15 @@ public class RunYOLO8n : MonoBehaviour
165
 
166
  var boxCoords = engine.PeekOutput("boxCoords") as TensorFloat;
167
  var NMS = engine.PeekOutput("NMS") as TensorInt;
 
168
 
169
  using var boxIDs = ops.Slice(NMS, new int[] { 2 }, new int[] { 3 }, new int[] { 1 }, new int[] { 1 });
170
  using var boxIDsFlat = boxIDs.ShallowReshape(new TensorShape(boxIDs.shape.length)) as TensorInt;
171
  using var output = ops.Gather(boxCoords, boxIDsFlat, 1);
 
172
 
173
  output.MakeReadable();
 
174
  NMS.MakeReadable();
175
 
176
  float displayWidth = displayImage.rectTransform.rect.width;
@@ -188,7 +196,7 @@ public class RunYOLO8n : MonoBehaviour
188
  centerY = output[0, n, 1] * scaleY - displayHeight / 2,
189
  width = output[0, n, 2] * scaleX,
190
  height = output[0, n, 3] * scaleY,
191
- label = labels[(int)NMS[n, 1]]
192
  };
193
  DrawBox(box, n);
194
  }
 
108
  model.AddConstant(new Lays.Constant("0", new int[] { 0 }));
109
  model.AddConstant(new Lays.Constant("1", new int[] { 1 }));
110
  model.AddConstant(new Lays.Constant("4", new int[] { 4 }));
111
+
112
+
113
  model.AddConstant(new Lays.Constant("classes_plus_4", new int[] { numClasses + 4 }));
114
  model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { maxOutputBoxes }));
115
  model.AddConstant(new Lays.Constant("iouThreshold", new float[] { iouThreshold }));
116
  model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { scoreThreshold }));
117
 
118
  //Add layers
119
+ model.AddLayer(new Lays.Slice("boxCoords0", "output0", "0", "4", "1"));
120
  model.AddLayer(new Lays.Transpose("boxCoords", "boxCoords0", new int[] { 0, 2, 1 }));
121
+ model.AddLayer(new Lays.Slice("scores0", "output0", "4", "classes_plus_4", "1"));
122
+ model.AddLayer(new Lays.ReduceMax("scores", new[] { "scores0", "1" }));
123
+ model.AddLayer(new Lays.ArgMax("classIDs", "scores0", 1));
124
 
125
+ model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores",
126
  "maxOutputBoxes", "iouThreshold", "scoreThreshold",
127
  centerPointBox: Lays.CenterPointBox.Center
128
  ));
129
 
130
  model.outputs.Clear();
131
  model.AddOutput("boxCoords");
132
+ model.AddOutput("classIDs");
133
  model.AddOutput("NMS");
134
  }
135
 
 
170
 
171
  var boxCoords = engine.PeekOutput("boxCoords") as TensorFloat;
172
  var NMS = engine.PeekOutput("NMS") as TensorInt;
173
+ var classIDs = engine.PeekOutput("classIDs") as TensorInt;
174
 
175
  using var boxIDs = ops.Slice(NMS, new int[] { 2 }, new int[] { 3 }, new int[] { 1 }, new int[] { 1 });
176
  using var boxIDsFlat = boxIDs.ShallowReshape(new TensorShape(boxIDs.shape.length)) as TensorInt;
177
  using var output = ops.Gather(boxCoords, boxIDsFlat, 1);
178
+ using var labelIDs = ops.Gather(classIDs, boxIDsFlat, 2);
179
 
180
  output.MakeReadable();
181
+ labelIDs.MakeReadable();
182
  NMS.MakeReadable();
183
 
184
  float displayWidth = displayImage.rectTransform.rect.width;
 
196
  centerY = output[0, n, 1] * scaleY - displayHeight / 2,
197
  width = output[0, n, 2] * scaleX,
198
  height = output[0, n, 3] * scaleY,
199
+ label = labels[labelIDs[0, 0,n]],
200
  };
201
  DrawBox(box, n);
202
  }