|
using System.Collections.Generic; |
|
using Unity.Sentis; |
|
using UnityEngine; |
|
using UnityEngine.UI; |
|
using UnityEngine.Video; |
|
using Lays = Unity.Sentis.Layers; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Runs a YOLOv8n object-detection model (Unity Sentis) on every frame of a looping video
/// from StreamingAssets and draws labelled bounding boxes over <see cref="displayImage"/>.
/// </summary>
/// <remarks>
/// <see cref="LoadModel"/> appends slice/transpose layers and a NonMaxSuppression layer to the
/// imported model, so box filtering happens inside the worker; <see cref="ExecuteML"/> only
/// gathers the surviving boxes and draws them.
/// </remarks>
public class RunYOLO8n : MonoBehaviour
{
    // Model file, loaded from Application.streamingAssetsPath.
    const string modelName = "yolov8n.sentis";

    // Video file, played from Application.streamingAssetsPath.
    const string videoName = "giraffes.mp4";

    // Class-label text asset: one label per line, indexed by the class id reported by NMS.
    public TextAsset labelsAsset;

    // UI image that shows the model's input frame; box panels are parented under it.
    public RawImage displayImage;

    // Sprite drawn (sliced) as the box outline.
    public Sprite boxTexture;

    // Font used for the box labels.
    public Font font;

    // Backend used for both the model worker and the standalone tensor ops.
    const BackendType backend = BackendType.GPUCompute;

    private Transform displayLocation;
    private Model model;
    private IWorker engine;
    private string[] labels;
    private RenderTexture targetRT;

    // Model input resolution; the video frame is blitted into a render texture of this size.
    private const int imageWidth = 640;
    private const int imageHeight = 640;

    // Number of class-score rows following the 4 box rows in the model's "output0".
    private const int numClasses = 80;

    private VideoPlayer video;

    // Box GameObjects reused across frames; the list index is the box slot id used by DrawBox.
    List<GameObject> boxPool = new List<GameObject>();

    // NOTE: these values are baked into the model as constants in LoadModel(), which runs once
    // in Start(). Changing them in the Inspector afterwards has no effect on detection.
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;
    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;
    // Passed to NMS as its max-output-boxes input (per class, following ONNX NMS semantics).
    int maxOutputBoxes = 64;

    // Standalone tensor ops used to post-process the worker outputs.
    Ops ops;

    /// <summary>
    /// A box ready to be drawn, in the display image's local UI space.
    /// </summary>
    public struct BoundingBox
    {
        // Box center, relative to the center of the display image (UI units).
        public float centerX;
        public float centerY;

        // Box size in display-image UI units.
        public float width;
        public float height;

        // Text shown on the box.
        public string label;
    }

    /// <summary>
    /// Unity entry point: configures the app, loads labels and the model, creates the worker,
    /// and starts video playback.
    /// </summary>
    void Start()
    {
        Application.targetFrameRate = 60;
        Screen.orientation = ScreenOrientation.LandscapeLeft;

        ops = WorkerFactory.CreateOps(backend, null);

        labels = LoadLabels(labelsAsset.text);

        LoadModel();

        targetRT = new RenderTexture(imageWidth, imageHeight, 0);

        displayLocation = displayImage.transform;

        engine = WorkerFactory.CreateWorker(backend, model);

        SetupInput();
    }

    /// <summary>
    /// Splits label text into one entry per line. A trailing '\r' is stripped from each entry so
    /// that files saved with CRLF line endings do not leak carriage returns into the UI text.
    /// </summary>
    static string[] LoadLabels(string text)
    {
        string[] lines = text.Split('\n');
        for (int i = 0; i < lines.Length; i++)
        {
            lines[i] = lines[i].TrimEnd('\r');
        }
        return lines;
    }

    /// <summary>
    /// Loads the YOLOv8n model and appends the post-processing layers: box/score slicing and
    /// NonMaxSuppression. The model outputs are replaced by "boxCoords" and "NMS".
    /// </summary>
    void LoadModel()
    {
        model = ModelLoader.Load(Application.streamingAssetsPath + "/" + modelName);

        // The class names are also stored in the model metadata.
        Debug.Log($"Class names: \n{model.Metadata["names"]}");

        // Constant inputs referenced by name in the layers added below.
        model.AddConstant(new Lays.Constant("0", new int[] { 0 }));
        model.AddConstant(new Lays.Constant("1", new int[] { 1 }));
        model.AddConstant(new Lays.Constant("4", new int[] { 4 }));
        model.AddConstant(new Lays.Constant("classes_plus_4", new int[] { numClasses + 4 }));
        model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { maxOutputBoxes }));
        model.AddConstant(new Lays.Constant("iouThreshold", new float[] { iouThreshold }));
        model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { scoreThreshold }));

        // Rows 0..3 of output0 (axis 1) hold the box coordinates; transposing with {0, 2, 1}
        // moves the box index to axis 1, which is the layout NMS expects and ExecuteML gathers on.
        model.AddLayer(new Lays.Slice("boxCoords0", "output0", "0", "4", "1"));
        model.AddLayer(new Lays.Transpose("boxCoords", "boxCoords0", new int[] { 0, 2, 1 }));
        // Rows 4..4+numClasses of output0 hold the per-class scores.
        model.AddLayer(new Lays.Slice("scores0", "output0", "4", "classes_plus_4", "1"));

        // The coordinates are center-format: ExecuteML reads them as (center x, center y,
        // width, height). NMS must be told so; CenterPointBox.Corners would interpret them as
        // (y1, x1, y2, x2) and compute meaningless overlaps.
        model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores0",
            "maxOutputBoxes", "iouThreshold", "scoreThreshold",
            centerPointBox: Lays.CenterPointBox.Center
        ));

        model.outputs.Clear();
        model.AddOutput("boxCoords");
        model.AddOutput("NMS");
    }

    /// <summary>
    /// Adds a looping, API-only VideoPlayer that plays <see cref="videoName"/> from StreamingAssets.
    /// </summary>
    void SetupInput()
    {
        video = gameObject.AddComponent<VideoPlayer>();
        video.renderMode = VideoRenderMode.APIOnly;
        video.source = VideoSource.Url;
        video.url = Application.streamingAssetsPath + "/" + videoName;
        video.isLooping = true;
        video.Play();
    }

    /// <summary>
    /// Unity per-frame callback: runs detection, and quits on Escape.
    /// </summary>
    private void Update()
    {
        ExecuteML();

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }
    }

    /// <summary>
    /// Runs the model on the current video frame and redraws the detected boxes.
    /// Does nothing beyond hiding the previous boxes until the video has produced a frame.
    /// </summary>
    public void ExecuteML()
    {
        ClearAnnotations();

        if (!video || !video.texture)
        {
            return;
        }

        // Scale the horizontal UV by 1/aspect so a square region of the frame fills the
        // square model input.
        float aspect = video.width * 1f / video.height;
        Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0));
        displayImage.texture = targetRT;

        using var input = TextureConverter.ToTensor(targetRT, imageWidth, imageHeight, 3);
        engine.Execute(input);

        // Peeked outputs are owned by the worker and must not be disposed here.
        var boxCoords = engine.PeekOutput("boxCoords") as TensorFloat;
        var NMS = engine.PeekOutput("NMS") as TensorInt;

        // Each NMS row is (batch index, class index, box index) per the ONNX NMS spec;
        // column 2 selects which boxes to keep.
        using var boxIDs = ops.Slice(NMS, new int[] { 2 }, new int[] { 3 }, new int[] { 1 }, new int[] { 1 });
        using var boxIDsFlat = boxIDs.ShallowReshape(new TensorShape(boxIDs.shape.length)) as TensorInt;
        using var output = ops.Gather(boxCoords, boxIDsFlat, 1);

        // Bring the data back to the CPU so it can be indexed below.
        output.MakeReadable();
        NMS.MakeReadable();

        float displayWidth = displayImage.rectTransform.rect.width;
        float displayHeight = displayImage.rectTransform.rect.height;

        // Map model-input pixels to display UI units.
        float scaleX = displayWidth / imageWidth;
        float scaleY = displayHeight / imageHeight;

        for (int n = 0; n < output.shape[1]; n++)
        {
            int classId = (int)NMS[n, 1];
            var box = new BoundingBox
            {
                // Shift so coordinates are relative to the display image's center.
                centerX = output[0, n, 0] * scaleX - displayWidth / 2,
                centerY = output[0, n, 1] * scaleY - displayHeight / 2,
                width = output[0, n, 2] * scaleX,
                height = output[0, n, 3] * scaleY,
                label = LabelFor(classId)
            };
            DrawBox(box, n);
        }
    }

    /// <summary>
    /// Returns the label for <paramref name="classId"/>, or the id itself as text when the
    /// labels asset has no entry for it (e.g. a labels file shorter than numClasses).
    /// </summary>
    string LabelFor(int classId)
    {
        return classId >= 0 && classId < labels.Length ? labels[classId] : classId.ToString();
    }

    /// <summary>
    /// Shows <paramref name="box"/> using pooled slot <paramref name="id"/>, creating a new box
    /// GameObject when the pool does not have that many yet.
    /// </summary>
    /// <param name="box">Box geometry and label in display UI space.</param>
    /// <param name="id">Pool slot to use.</param>
    public void DrawBox(BoundingBox box , int id)
    {
        GameObject panel;
        if (id < boxPool.Count)
        {
            panel = boxPool[id];
            panel.SetActive(true);
        }
        else
        {
            panel = CreateNewBox(Color.yellow);
        }

        // UI y grows upwards, while the centerY value grows downwards, hence the negation.
        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);

        RectTransform rt = panel.GetComponent<RectTransform>();
        rt.sizeDelta = new Vector2(box.width, box.height);

        var label = panel.GetComponentInChildren<Text>();
        label.text = box.label;
    }

    /// <summary>
    /// Creates a box panel: a sliced <see cref="boxTexture"/> image tinted with
    /// <paramref name="color"/>, plus a child Text label. The panel is parented under the display
    /// image and appended to the pool.
    /// </summary>
    /// <param name="color">Tint for both the box image and the label text.</param>
    /// <returns>The new panel GameObject.</returns>
    public GameObject CreateNewBox(Color color)
    {
        var panel = new GameObject("ObjectBox");
        panel.AddComponent<CanvasRenderer>();
        Image img = panel.AddComponent<Image>();
        img.color = color;
        img.sprite = boxTexture;
        img.type = Image.Type.Sliced;
        panel.transform.SetParent(displayLocation, false);

        var text = new GameObject("ObjectLabel");
        text.AddComponent<CanvasRenderer>();
        text.transform.SetParent(panel.transform, false);
        Text txt = text.AddComponent<Text>();
        txt.font = font;
        txt.color = color;
        txt.fontSize = 40;
        txt.horizontalOverflow = HorizontalWrapMode.Overflow;

        // The order of these assignments matters: the offsets are set before the anchors are
        // stretched, which determines the final label rect.
        RectTransform rt2 = text.GetComponent<RectTransform>();
        rt2.offsetMin = new Vector2(20, rt2.offsetMin.y);
        rt2.offsetMax = new Vector2(0, rt2.offsetMax.y);
        rt2.offsetMin = new Vector2(rt2.offsetMin.x, 0);
        rt2.offsetMax = new Vector2(rt2.offsetMax.x, 30);
        rt2.anchorMin = new Vector2(0, 0);
        rt2.anchorMax = new Vector2(1, 1);

        boxPool.Add(panel);
        return panel;
    }

    /// <summary>
    /// Hides every pooled box; boxes are re-shown by DrawBox as needed.
    /// </summary>
    public void ClearAnnotations()
    {
        foreach (var box in boxPool)
        {
            box.SetActive(false);
        }
    }

    /// <summary>
    /// Unity teardown: disposes the worker and ops, and frees the render target.
    /// </summary>
    private void OnDestroy()
    {
        engine?.Dispose();
        ops?.Dispose();

        // Created with `new` in Start(); Unity does not free it with this component, so release
        // its GPU memory and destroy the object explicitly. Unity's overloaded != is used on purpose.
        if (targetRT != null)
        {
            targetRT.Release();
            Destroy(targetRT);
        }
    }
}
|
|