ppo-Pyramids-Training
/
com.unity.ml-agents.extensions
/Tests
/Runtime
/Sensors
/PoseExtractorTests.cs
using System; | |
using UnityEngine; | |
using NUnit.Framework; | |
using Unity.MLAgents.Extensions.Sensors; | |
namespace Unity.MLAgents.Extensions.Tests.Sensors | |
{ | |
public class PoseExtractorTests | |
{ | |
class BasicPoseExtractor : PoseExtractor | |
{ | |
protected internal override Pose GetPoseAt(int index) | |
{ | |
return Pose.identity; | |
} | |
protected internal override Vector3 GetLinearVelocityAt(int index) | |
{ | |
return Vector3.zero; | |
} | |
} | |
class UselessPoseExtractor : BasicPoseExtractor | |
{ | |
public void Init(int[] parentIndices) | |
{ | |
Setup(parentIndices); | |
} | |
} | |
[ | ]|
public void TestEmptyExtractor() | |
{ | |
var poseExtractor = new UselessPoseExtractor(); | |
// These should be no-ops | |
poseExtractor.UpdateLocalSpacePoses(); | |
poseExtractor.UpdateModelSpacePoses(); | |
Assert.AreEqual(0, poseExtractor.NumPoses); | |
// Iterating through poses and velocities should be an empty loop | |
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses()) | |
{ | |
throw new UnityAgentsException("This shouldn't happen"); | |
} | |
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses()) | |
{ | |
throw new UnityAgentsException("This shouldn't happen"); | |
} | |
foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities()) | |
{ | |
throw new UnityAgentsException("This shouldn't happen"); | |
} | |
foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) | |
{ | |
throw new UnityAgentsException("This shouldn't happen"); | |
} | |
// Getting a parent index should throw an index exception | |
Assert.Throws<NullReferenceException>( | |
() => poseExtractor.GetParentIndex(0) | |
); | |
// DisplayNodes should be empty | |
var displayNodes = poseExtractor.GetDisplayNodes(); | |
Assert.AreEqual(0, displayNodes.Count); | |
} | |
[ | ]|
public void TestSimpleExtractor() | |
{ | |
var poseExtractor = new UselessPoseExtractor(); | |
var parentIndices = new[] { -1, 0 }; | |
poseExtractor.Init(parentIndices); | |
Assert.AreEqual(2, poseExtractor.NumPoses); | |
} | |
/// <summary> | |
/// A simple "chain" hierarchy, where each object is parented to the one before it. | |
/// 0 <- 1 <- 2 <- ... | |
/// </summary> | |
class ChainPoseExtractor : PoseExtractor | |
{ | |
public Vector3 offset; | |
public ChainPoseExtractor(int size) | |
{ | |
var parents = new int[size]; | |
for (var i = 0; i < size; i++) | |
{ | |
parents[i] = i - 1; | |
} | |
Setup(parents); | |
} | |
protected internal override Pose GetPoseAt(int index) | |
{ | |
var rotation = Quaternion.identity; | |
var translation = offset + new Vector3(index, index, index); | |
return new Pose | |
{ | |
rotation = rotation, | |
position = translation | |
}; | |
} | |
protected internal override Vector3 GetLinearVelocityAt(int index) | |
{ | |
return Vector3.zero; | |
} | |
} | |
[ | ]|
public void TestChain() | |
{ | |
var size = 4; | |
var chain = new ChainPoseExtractor(size); | |
chain.offset = new Vector3(.5f, .75f, .333f); | |
chain.UpdateModelSpacePoses(); | |
chain.UpdateLocalSpacePoses(); | |
var modelPoseIndex = 0; | |
foreach (var modelSpace in chain.GetEnabledModelSpacePoses()) | |
{ | |
if (modelPoseIndex == 0) | |
{ | |
// Root transforms are currently always the identity. | |
Assert.IsTrue(modelSpace == Pose.identity); | |
} | |
else | |
{ | |
var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex); | |
Assert.IsTrue(expectedModelTranslation == modelSpace.position); | |
} | |
modelPoseIndex++; | |
} | |
Assert.AreEqual(size, modelPoseIndex); | |
var localPoseIndex = 0; | |
foreach (var localSpace in chain.GetEnabledLocalSpacePoses()) | |
{ | |
if (localPoseIndex == 0) | |
{ | |
// Root transforms are currently always the identity. | |
Assert.IsTrue(localSpace == Pose.identity); | |
} | |
else | |
{ | |
var expectedLocalTranslation = new Vector3(1, 1, 1); | |
Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}"); | |
} | |
localPoseIndex++; | |
} | |
Assert.AreEqual(size, localPoseIndex); | |
} | |
[ | ]|
public void TestChainDisplayNodes() | |
{ | |
var size = 4; | |
var chain = new ChainPoseExtractor(size); | |
var displayNodes = chain.GetDisplayNodes(); | |
Assert.AreEqual(size, displayNodes.Count); | |
for (var i = 0; i < size; i++) | |
{ | |
var displayNode = displayNodes[i]; | |
Assert.AreEqual(i, displayNode.OriginalIndex); | |
Assert.AreEqual(null, displayNode.NodeObject); | |
Assert.AreEqual(i, displayNode.Depth); | |
Assert.AreEqual(true, displayNode.Enabled); | |
} | |
} | |
[ | ]|
public void TestDisplayNodesLoop() | |
{ | |
// Degenerate case with a loop | |
var poseExtractor = new UselessPoseExtractor(); | |
poseExtractor.Init(new[] { -1, 2, 1 }); | |
// This just shouldn't blow up | |
poseExtractor.GetDisplayNodes(); | |
// Self-loop | |
poseExtractor.Init(new[] { -1, 1 }); | |
// This just shouldn't blow up | |
poseExtractor.GetDisplayNodes(); | |
} | |
class BadPoseExtractor : BasicPoseExtractor | |
{ | |
public BadPoseExtractor() | |
{ | |
var size = 2; | |
var parents = new int[size]; | |
// Parents are intentionally invalid - expect -1 at root | |
for (var i = 0; i < size; i++) | |
{ | |
parents[i] = i; | |
} | |
Setup(parents); | |
} | |
} | |
[ | ]|
public void TestExpectedRoot() | |
{ | |
Assert.Throws<UnityAgentsException>(() => | |
{ | |
var unused = new BadPoseExtractor(); | |
}); | |
} | |
} | |
public class PoseExtensionTests | |
{ | |
[ | ]|
public void TestInverse() | |
{ | |
Pose t = new Pose | |
{ | |
rotation = Quaternion.AngleAxis(23.0f, new Vector3(1, 1, 1).normalized), | |
position = new Vector3(-1.0f, 2.0f, 3.0f) | |
}; | |
var inverseT = t.Inverse(); | |
var product = inverseT.Multiply(t); | |
Assert.IsTrue(Vector3.zero == product.position); | |
Assert.IsTrue(Quaternion.identity == product.rotation); | |
Assert.IsTrue(Pose.identity == product); | |
} | |
} | |
} | |