File size: 11,309 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent
{
    [Header("Walk Speed")]
    [Range(k_MinWalkingSpeed, m_maxWalkingSpeed)]
    [SerializeField]
    [Tooltip(
        "The speed the agent will try to match.\n\n" +
        "TRAINING:\n" +
        "For VariableSpeed envs, this value will randomize at the start of each training episode.\n" +
        "Otherwise the agent will try to match the speed set here.\n\n" +
        "INFERENCE:\n" +
        "During inference, VariableSpeed agents will modify their behavior based on this value " +
        "whereas the CrawlerDynamic & CrawlerStatic agents will run at the speed specified during training "
    )]
    //The walking speed to try and achieve
    private float m_TargetWalkingSpeed = m_maxWalkingSpeed;

    const float m_maxWalkingSpeed = 15; //The max walking speed

    /// <summary>
    /// Lowest allowed target walking speed. A speed of zero would make
    /// <see cref="GetMatchingVelocityReward"/> divide by zero and produce NaNs.
    /// </summary>
    const float k_MinWalkingSpeed = 0.1f;

    /// <summary>
    /// Maximum distance (world units) of the downward ground raycast used for the height observation.
    /// </summary>
    const float k_MaxGroundRaycastDistance = 10f;

    /// <summary>
    /// The current target walking speed, clamped to [k_MinWalkingSpeed, m_maxWalkingSpeed]
    /// because a value of zero will cause NaNs in the reward.
    /// </summary>
    public float TargetWalkingSpeed
    {
        get { return m_TargetWalkingSpeed; }
        set { m_TargetWalkingSpeed = Mathf.Clamp(value, k_MinWalkingSpeed, m_maxWalkingSpeed); }
    }

    //The direction an agent will walk during training.
    [Header("Target To Walk Towards")]
    public Transform TargetPrefab; //Target prefab to use in Dynamic envs
    private Transform m_Target; //Target the agent will walk towards during training.

    [Header("Body Parts")]
    [Space(10)]
    public Transform body;
    public Transform leg0Upper;
    public Transform leg0Lower;
    public Transform leg1Upper;
    public Transform leg1Lower;
    public Transform leg2Upper;
    public Transform leg2Lower;
    public Transform leg3Upper;
    public Transform leg3Lower;

    //This will be used as a stabilized model space reference point for observations
    //Because ragdolls can move erratically during training, using a stabilized reference transform improves learning
    OrientationCubeController m_OrientationCube;

    //The indicator graphic gameobject that points towards the target
    DirectionIndicator m_DirectionIndicator;
    JointDriveController m_JdController;

    [Header("Foot Grounded Visualization")]
    [Space(10)]
    public bool useFootGroundedVisualization;
    public MeshRenderer foot0;
    public MeshRenderer foot1;
    public MeshRenderer foot2;
    public MeshRenderer foot3;
    public Material groundedMaterial;
    public Material unGroundedMaterial;

    /// <summary>
    /// One-time setup: spawns the walk target, caches the orientation helpers and
    /// registers every body part with the <see cref="JointDriveController"/>.
    /// The registration order below determines the order of per-body-part observations.
    /// </summary>
    public override void Initialize()
    {
        SpawnTarget(TargetPrefab, transform.position); //spawn target

        m_OrientationCube = GetComponentInChildren<OrientationCubeController>();
        m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
        m_JdController = GetComponent<JointDriveController>();

        //Setup each body part
        m_JdController.SetupBodyPart(body);
        m_JdController.SetupBodyPart(leg0Upper);
        m_JdController.SetupBodyPart(leg0Lower);
        m_JdController.SetupBodyPart(leg1Upper);
        m_JdController.SetupBodyPart(leg1Lower);
        m_JdController.SetupBodyPart(leg2Upper);
        m_JdController.SetupBodyPart(leg2Lower);
        m_JdController.SetupBodyPart(leg3Upper);
        m_JdController.SetupBodyPart(leg3Lower);
    }

    /// <summary>
    /// Instantiates <paramref name="prefab"/> at <paramref name="pos"/> (identity rotation),
    /// parented to this agent's parent transform, and stores it as the walk target.
    /// </summary>
    /// <param name="prefab">Target prefab to instantiate.</param>
    /// <param name="pos">World position to spawn the target at.</param>
    void SpawnTarget(Transform prefab, Vector3 pos)
    {
        m_Target = Instantiate(prefab, pos, Quaternion.identity, transform.parent);
    }

    /// <summary>
    /// Resets every body part, gives the body a random yaw, refreshes the orientation
    /// helpers and draws a new random target walking speed.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
        {
            bodyPart.Reset(bodyPart);
        }

        //Random start rotation to help generalize
        body.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

        UpdateOrientationObjects();

        //Set our goal walking speed
        TargetWalkingSpeed = Random.Range(k_MinWalkingSpeed, m_maxWalkingSpeed);
    }

    /// <summary>
    /// Adds the observations for a single body part: whether it touches the ground and,
    /// for every part except the main body, its current joint strength normalized by
    /// the controller's maximum joint force limit.
    /// </summary>
    /// <param name="bp">Body part to observe.</param>
    /// <param name="sensor">Sensor receiving the observations.</param>
    public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
    {
        //GROUND CHECK
        sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground

        // The main body has no joint to drive, so its strength is not observed.
        if (bp.rb.transform != body)
        {
            sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
        }
    }

    /// <summary>
    /// Collects the agent-level observations (velocity goal/actual, heading delta,
    /// target position, ground height) followed by per-body-part observations.
    /// The order of AddObservation calls must stay fixed to remain compatible with trained models.
    /// </summary>
    /// <param name="sensor">Sensor receiving the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        var cubeTransform = m_OrientationCube.transform;
        var cubeForward = cubeTransform.forward;

        //velocity we want to match
        var velGoal = cubeForward * TargetWalkingSpeed;
        //ragdoll's avg vel
        var avgVel = GetAvgVelocity();

        //distance between the goal velocity and the current average velocity (not normalized)
        sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
        //avg body vel relative to cube
        sensor.AddObservation(cubeTransform.InverseTransformDirection(avgVel));
        //vel goal relative to cube
        sensor.AddObservation(cubeTransform.InverseTransformDirection(velGoal));
        //rotation delta
        sensor.AddObservation(Quaternion.FromToRotation(body.forward, cubeForward));
        //Add pos of target relative to orientation cube
        sensor.AddObservation(cubeTransform.InverseTransformPoint(m_Target.position));

        //Distance to whatever is directly below the body, normalized by the ray length; 1 if nothing was hit
        RaycastHit hit;
        if (Physics.Raycast(body.position, Vector3.down, out hit, k_MaxGroundRaycastDistance))
        {
            sensor.AddObservation(hit.distance / k_MaxGroundRaycastDistance);
        }
        else
        {
            sensor.AddObservation(1);
        }

        foreach (var bodyPart in m_JdController.bodyPartsList)
        {
            CollectObservationBodyPart(bodyPart, sensor);
        }
    }

    /// <summary>
    /// Applies the policy's 20 continuous actions, in this fixed order:
    /// 8 upper-leg joint targets (x and y for legs 0-3), 4 lower-leg joint targets (x only),
    /// then 8 joint strengths (upper legs 0-3, lower legs 0-3).
    /// </summary>
    /// <param name="actionBuffers">Actions produced by the policy.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // The dictionary with all the body parts in it are in the jdController
        var bpDict = m_JdController.bodyPartsDict;

        var continuousActions = actionBuffers.ContinuousActions;
        var i = -1;
        // Pick a new target joint rotation
        bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);

        // Update joint strength
        bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
    }

    /// <summary>
    /// Per-physics-step update: refreshes orientation helpers, optionally updates the
    /// foot grounded visualization, and adds the shaped step reward
    /// (speed match multiplied by heading alignment).
    /// </summary>
    void FixedUpdate()
    {
        UpdateOrientationObjects();

        // If enabled, each foot shows groundedMaterial while its lower leg touches the ground.
        // This is just a visualization and isn't necessary for function
        if (useFootGroundedVisualization)
        {
            SetFootGroundedMaterial(foot0, leg0Lower);
            SetFootGroundedMaterial(foot1, leg1Lower);
            SetFootGroundedMaterial(foot2, leg2Lower);
            SetFootGroundedMaterial(foot3, leg3Lower);
        }

        var cubeForward = m_OrientationCube.transform.forward;

        // Set reward for this step according to mixture of the following elements.
        // a. Match target speed
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        var matchSpeedReward = GetMatchingVelocityReward(cubeForward * TargetWalkingSpeed, GetAvgVelocity());

        // b. Rotation alignment with target direction.
        //This reward will approach 1 if it faces the target direction perfectly and approach zero as it deviates
        var lookAtTargetReward = (Vector3.Dot(cubeForward, body.forward) + 1) * .5F;

        AddReward(matchSpeedReward * lookAtTargetReward);
    }

    /// <summary>
    /// Assigns groundedMaterial to <paramref name="foot"/> while the body part registered for
    /// <paramref name="lowerLeg"/> is touching the ground, and unGroundedMaterial otherwise.
    /// </summary>
    /// <param name="foot">Renderer of the foot to recolor.</param>
    /// <param name="lowerLeg">Lower-leg transform whose ground contact drives the color.</param>
    void SetFootGroundedMaterial(MeshRenderer foot, Transform lowerLeg)
    {
        foot.material = m_JdController.bodyPartsDict[lowerLeg].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
    }

    /// <summary>
    /// Update OrientationCube and DirectionIndicator (the indicator is optional).
    /// </summary>
    void UpdateOrientationObjects()
    {
        m_OrientationCube.UpdateOrientation(body, m_Target);
        if (m_DirectionIndicator)
        {
            m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
        }
    }

    /// <summary>
    ///Returns the average velocity of all of the body parts
    ///Using the velocity of the body only has shown to result in more erratic movement from the limbs
    ///Using the average helps prevent this erratic movement
    /// </summary>
    /// <returns>The mean rigidbody velocity, or <see cref="Vector3.zero"/> if no body parts are registered.</returns>
    Vector3 GetAvgVelocity()
    {
        var velSum = Vector3.zero;

        //ALL RBS
        var numOfRb = 0;
        foreach (var item in m_JdController.bodyPartsList)
        {
            numOfRb++;
            velSum += item.rb.velocity;
        }

        // Guard against an empty body-part list: dividing by zero would produce a NaN
        // vector that propagates into both observations and rewards.
        return numOfRb > 0 ? velSum / numOfRb : Vector3.zero;
    }

    /// <summary>
    /// Normalized value of the difference in actual speed vs goal walking speed.
    /// </summary>
    /// <param name="velocityGoal">Velocity the agent should match.</param>
    /// <param name="actualVelocity">Velocity the agent currently has.</param>
    /// <returns>
    /// A value in [0, 1]: 1 when the velocities match exactly, falling to 0 once their
    /// distance reaches <see cref="TargetWalkingSpeed"/>.
    /// </returns>
    public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
    {
        //distance between our actual velocity and goal velocity
        var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, TargetWalkingSpeed);

        //return the value on a declining sigmoid shaped curve that decays from 1 to 0
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / TargetWalkingSpeed, 2), 2);
    }

    /// <summary>
    /// Agent touched the target: adds a reward of 1.
    /// NOTE(review): the caller is not visible here; presumably the target's collision logic.
    /// </summary>
    public void TouchedTarget()
    {
        AddReward(1f);
    }
}
|