using System.Collections.Generic;
using Unity.Sentis;
using UnityEngine;
using UnityEngine.UI;
using UnityEngine.Video;
using Lays = Unity.Sentis.Layers;

/*
 *  YOLOv8n Inference Script
 *  ========================
 *
 *  Place this script on the Main Camera.
 *
 *  Place the yolov8n.sentis file and a *.mp4 video file in the Assets/StreamingAssets folder.
 *  Create a RawImage in your scene and set it as the displayImage field.
 *  Drag the classes.txt into the labelsAsset field.
 *  Add a reference to a sprite image for the bounding box and a font for the text.
 *
 */

/// <summary>
/// Runs a YOLOv8n detector on every frame of a looping video from StreamingAssets and
/// draws the boxes kept by non-max suppression, with their class labels, as UI panels
/// parented under <see cref="displayImage"/>.
/// </summary>
public class RunYOLO8n : MonoBehaviour
{
    const string modelName = "yolov8n.sentis";
    // Change this to the name of the video you put in StreamingAssets folder:
    const string videoName = "giraffes.mp4";
    // Link the classes.txt here:
    public TextAsset labelsAsset;
    // Create a Raw Image in the scene and link it here:
    public RawImage displayImage;
    // Link to a bounding box texture here:
    public Sprite boxTexture;
    // Link to the font for the labels:
    public Font font;

    const BackendType backend = BackendType.GPUCompute;

    private Transform displayLocation;
    private Model model;
    private IWorker engine;
    private string[] labels;
    private RenderTexture targetRT;

    //Image size for the model
    private const int imageWidth = 640;
    private const int imageHeight = 640;

    //The number of classes in the model
    private const int numClasses = 80;

    private VideoPlayer video;

    // Box panels reused across frames: entry n is the panel drawn for the n-th detection.
    List<GameObject> boxPool = new List<GameObject>();

    // NOTE: these thresholds and maxOutputBoxes are only read in LoadModel(), where they are
    // baked into the model as constants, so changing them after Start() has no effect.
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;
    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;
    int maxOutputBoxes = 64;

    //For using tensor operators:
    Ops ops;

    /// <summary>
    /// One detection to draw, in display-space units: the centre is an offset from the
    /// middle of the display image, width/height are the scaled box size.
    /// </summary>
    public struct BoundingBox
    {
        public float centerX;
        public float centerY;
        public float width;
        public float height;
        public string label;
    }

    /// <summary>
    /// Configures the app, parses the labels, builds the model and worker, and starts the video.
    /// </summary>
    void Start()
    {
        Application.targetFrameRate = 60;
        Screen.orientation = ScreenOrientation.LandscapeLeft;

        ops = WorkerFactory.CreateOps(backend, null);

        //Parse neural net labels
        labels = labelsAsset.text.Split('\n');
        for (int i = 0; i < labels.Length; i++)
        {
            // A classes.txt saved with Windows (CRLF) line endings leaves a '\r' on every
            // entry after splitting on '\n'; strip it so it doesn't leak into the labels.
            labels[i] = labels[i].TrimEnd('\r');
        }

        LoadModel();

        targetRT = new RenderTexture(imageWidth, imageHeight, 0);

        //Create image to display video
        displayLocation = displayImage.transform;

        //Create engine to run model
        engine = WorkerFactory.CreateWorker(backend, model);

        SetupInput();
    }

    /// <summary>
    /// Loads the model from StreamingAssets and appends the layers that split "output0" into
    /// box coordinates and class scores and run non-max suppression over them.
    /// </summary>
    void LoadModel()
    {
        //Load model
        model = ModelLoader.Load(Application.streamingAssetsPath + "/" + modelName);

        //The classes are also stored here in JSON format:
        Debug.Log($"Class names: \n{model.Metadata["names"]}");

        //We need to add some layers to choose the best boxes with the NMSLayer

        //Set constants
        model.AddConstant(new Lays.Constant("0", new int[] { 0 }));
        model.AddConstant(new Lays.Constant("1", new int[] { 1 }));
        model.AddConstant(new Lays.Constant("4", new int[] { 4 }));

        model.AddConstant(new Lays.Constant("classes_plus_4", new int[] { numClasses + 4 }));
        model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { maxOutputBoxes }));
        model.AddConstant(new Lays.Constant("iouThreshold", new float[] { iouThreshold }));
        model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { scoreThreshold }));

        //Add layers
        // Assumes output0 is laid out as (batch, 4 + numClasses, anchors): channels 0..3 are the
        // box, channels 4..(4 + numClasses) are the class scores.
        model.AddLayer(new Lays.Slice("boxCoords0", "output0", "0", "4", "1"));
        // Transpose to (batch, anchors, 4), one box per row, as NonMaxSuppression expects.
        model.AddLayer(new Lays.Transpose("boxCoords", "boxCoords0", new int[] { 0, 2, 1 }));
        model.AddLayer(new Lays.Slice("scores0", "output0", "4", "classes_plus_4", "1"));
        // The box rows are [centerX, centerY, width, height] (ExecuteML reads them that way),
        // so NMS must be told the boxes are centre-format. Using Corners would make it
        // compute IoU on misread coordinates and suppress the wrong boxes.
        model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores0",
            "maxOutputBoxes", "iouThreshold", "scoreThreshold",
            centerPointBox: Lays.CenterPointBox.Center
        ));

        model.outputs.Clear();
        model.AddOutput("boxCoords");
        model.AddOutput("NMS");
    }

    /// <summary>
    /// Adds a VideoPlayer that decodes the StreamingAssets video (API-only, looping) as model input.
    /// </summary>
    void SetupInput()
    {
        video = gameObject.AddComponent<VideoPlayer>();
        video.renderMode = VideoRenderMode.APIOnly;
        video.source = VideoSource.Url;
        video.url = Application.streamingAssetsPath + "/" + videoName;
        video.isLooping = true;
        video.Play();
    }

    private void Update()
    {
        ExecuteML();

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }
    }

    /// <summary>
    /// Copies the current video frame into the model-input texture, runs the model, and draws
    /// one box per detection kept by NMS. Does nothing beyond hiding old boxes if no frame is ready.
    /// </summary>
    public void ExecuteML()
    {
        ClearAnnotations();

        if (!video || !video.texture)
        {
            return;
        }

        float aspect = video.width * 1f / video.height;
        // Scale U by 1/aspect: for a landscape video this copies a square region starting at
        // the left edge of the frame into the square model-input texture.
        Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0));
        displayImage.texture = targetRT;

        using var input = TextureConverter.ToTensor(targetRT, imageWidth, imageHeight, 3);
        engine.Execute(input);

        // (1, anchors, 4): every candidate box in model-input pixels.
        var boxCoords = engine.PeekOutput("boxCoords") as TensorFloat;
        // One row per kept box: column 1 is used as the class index, column 2 as the box index.
        var NMS = engine.PeekOutput("NMS") as TensorInt;

        // Take the box-index column, flatten it, and gather the kept boxes -> (1, kept, 4).
        using var boxIDs = ops.Slice(NMS, new int[] { 2 }, new int[] { 3 }, new int[] { 1 }, new int[] { 1 });
        using var boxIDsFlat = boxIDs.ShallowReshape(new TensorShape(boxIDs.shape.length)) as TensorInt;
        using var output = ops.Gather(boxCoords, boxIDsFlat, 1);

        output.MakeReadable();
        NMS.MakeReadable();

        float displayWidth = displayImage.rectTransform.rect.width;
        float displayHeight = displayImage.rectTransform.rect.height;

        float scaleX = displayWidth / imageWidth;
        float scaleY = displayHeight / imageHeight;

        //Draw the bounding boxes
        for (int n = 0; n < output.shape[1]; n++)
        {
            var box = new BoundingBox
            {
                // Scale from model pixels to display units, then shift so (0,0) is the display centre.
                centerX = output[0, n, 0] * scaleX - displayWidth / 2,
                centerY = output[0, n, 1] * scaleY - displayHeight / 2,
                width = output[0, n, 2] * scaleX,
                height = output[0, n, 3] * scaleY,
                label = labels[(int)NMS[n, 1]]
            };
            DrawBox(box, n);
        }
    }

    /// <summary>
    /// Positions, sizes and labels the pooled panel for detection <paramref name="id"/>,
    /// creating a new panel when the pool is not yet large enough.
    /// </summary>
    public void DrawBox(BoundingBox box, int id)
    {
        //Create the bounding box graphic or get from pool
        GameObject panel;
        if (id < boxPool.Count)
        {
            panel = boxPool[id];
            panel.SetActive(true);
        }
        else
        {
            panel = CreateNewBox(Color.yellow);
        }

        //Set box position
        // NOTE(review): y is negated when placing the panel; presumably the model's y grows
        // downward while the UI's grows upward — confirm against TextureConverter's flip.
        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);

        //Set box size
        RectTransform rt = panel.GetComponent<RectTransform>();
        rt.sizeDelta = new Vector2(box.width, box.height);

        //Set label text
        var label = panel.GetComponentInChildren<Text>();
        label.text = box.label;
    }

    /// <summary>
    /// Builds a sliced-sprite box panel with a child Text label, parents it under the display
    /// image, adds it to the pool, and returns it.
    /// </summary>
    public GameObject CreateNewBox(Color color)
    {
        //Create the box and set image
        var panel = new GameObject("ObjectBox");
        panel.AddComponent<CanvasRenderer>();
        Image img = panel.AddComponent<Image>();
        img.color = color;
        img.sprite = boxTexture;
        img.type = Image.Type.Sliced;
        panel.transform.SetParent(displayLocation, false);

        //Create the label
        var text = new GameObject("ObjectLabel");
        text.AddComponent<CanvasRenderer>();
        text.transform.SetParent(panel.transform, false);
        Text txt = text.AddComponent<Text>();
        txt.font = font;
        txt.color = color;
        txt.fontSize = 40;
        txt.horizontalOverflow = HorizontalWrapMode.Overflow;

        // Stretch the label over the panel, inset 20 units on the left and raised 30 at the top.
        RectTransform rt2 = text.GetComponent<RectTransform>();
        rt2.offsetMin = new Vector2(20, rt2.offsetMin.y);
        rt2.offsetMax = new Vector2(0, rt2.offsetMax.y);
        rt2.offsetMin = new Vector2(rt2.offsetMin.x, 0);
        rt2.offsetMax = new Vector2(rt2.offsetMax.x, 30);
        rt2.anchorMin = new Vector2(0, 0);
        rt2.anchorMax = new Vector2(1, 1);

        boxPool.Add(panel);
        return panel;
    }

    /// <summary>Hides every pooled box panel; panels are kept for reuse, not destroyed.</summary>
    public void ClearAnnotations()
    {
        foreach (var box in boxPool)
        {
            box.SetActive(false);
        }
    }

    /// <summary>Releases the worker, the ops object and the model-input render texture.</summary>
    private void OnDestroy()
    {
        engine?.Dispose();
        ops?.Dispose();

        // targetRT was created with `new`, so it must be released explicitly or its GPU memory
        // leaks. Use an explicit null check: `?.` bypasses UnityEngine.Object's null semantics.
        if (targetRT != null)
        {
            targetRT.Release();
            Destroy(targetRT);
        }
    }
}