using System;
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;

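// Agent for the GridWorld environment. On each decision it moves one cell on the
// grid and is rewarded for reaching the target that matches its current goal
// (green plus or red ex).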
public class GridAgent : Agent
{
    [FormerlySerializedAs("m_Area")]
    [Header("Specific to GridWorld")]
    public GridArea area;

    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;

[Tooltip("Because we want an observation right before making a decision, we can force " + |
|
"a camera to render before making a decision. Place the agentCam here if using " + |
|
"RenderTexture as observations.")] |
|
public Camera renderCamera; |
|
|
|
VectorSensorComponent m_GoalSensor; |
|
|
|
    public enum GridGoal
    {
        GreenPlus,
        RedEx,
    }

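    // Indicator objects toggled under the agent to show which goal is currently active.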
    public GameObject GreenBottom;
    public GameObject RedBottom;

    GridGoal m_CurrentGoal;

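    // Setting the goal also switches the matching indicator object on and the other off,
    // so the active goal is visible in the scene.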
    public GridGoal CurrentGoal
    {
        get { return m_CurrentGoal; }
        set
        {
            switch (value)
            {
                case GridGoal.GreenPlus:
                    GreenBottom.SetActive(true);
                    RedBottom.SetActive(false);
                    break;
                case GridGoal.RedEx:
                    GreenBottom.SetActive(false);
                    RedBottom.SetActive(true);
                    break;
            }
            m_CurrentGoal = value;
        }
    }

[Tooltip("Selecting will turn on action masking. Note that a model trained with action " + |
|
"masking turned on may not behave optimally when action masking is turned off.")] |
|
public bool maskActions = true; |
|
|
|
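    // Discrete action encoding for the single action branch.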
    const int k_NoAction = 0;
    const int k_Up = 1;
    const int k_Down = 2;
    const int k_Left = 3;
    const int k_Right = 4;

    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_GoalSensor = this.GetComponent<VectorSensorComponent>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

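    // Writes the current goal as a one-hot observation to the goal sensor,
    // if a VectorSensorComponent is attached to the agent.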
    public override void CollectObservations(VectorSensor sensor)
    {
        Array values = Enum.GetValues(typeof(GridGoal));

        if (m_GoalSensor is object)
        {
            int goalNum = (int)CurrentGoal;
            m_GoalSensor.GetSensor().AddOneHotObservation(goalNum, values.Length);
        }
    }

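    // Masks out movement actions that would take the agent off the edge of the grid.
    // The grid size comes from the "gridSize" environment parameter (default 5).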
    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        if (maskActions)
        {
            var positionX = (int)transform.localPosition.x;
            var positionZ = (int)transform.localPosition.z;
            var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;

            if (positionX == 0)
            {
                actionMask.SetActionEnabled(0, k_Left, false);
            }

            if (positionX == maxPosition)
            {
                actionMask.SetActionEnabled(0, k_Right, false);
            }

            if (positionZ == 0)
            {
                actionMask.SetActionEnabled(0, k_Down, false);
            }

            if (positionZ == maxPosition)
            {
                actionMask.SetActionEnabled(0, k_Up, false);
            }
        }
    }

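    // Applies a small per-step penalty, converts the chosen action into a target cell,
    // and moves there unless a wall blocks it. Reaching a plus or an ex ends the episode.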
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        AddReward(-0.01f);
        var action = actionBuffers.DiscreteActions[0];

        var targetPos = transform.position;
        switch (action)
        {
            case k_NoAction:
                break;
            case k_Right:
                targetPos = transform.position + new Vector3(1f, 0, 0f);
                break;
            case k_Left:
                targetPos = transform.position + new Vector3(-1f, 0, 0f);
                break;
            case k_Up:
                targetPos = transform.position + new Vector3(0f, 0, 1f);
                break;
            case k_Down:
                targetPos = transform.position + new Vector3(0f, 0, -1f);
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        var hit = Physics.OverlapBox(
            targetPos, new Vector3(0.3f, 0.3f, 0.3f));
        if (hit.Where(col => col.gameObject.CompareTag("wall")).ToArray().Length == 0)
        {
            transform.position = targetPos;

            if (hit.Where(col => col.gameObject.CompareTag("plus")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.GreenPlus);
                EndEpisode();
            }
            else if (hit.Where(col => col.gameObject.CompareTag("ex")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.RedEx);
                EndEpisode();
            }
        }
    }

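    // Rewards +1 if the object the agent reached matches its goal, -1 otherwise.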
    private void ProvideReward(GridGoal hitObject)
    {
        if (CurrentGoal == hitObject)
        {
            SetReward(1f);
        }
        else
        {
            SetReward(-1f);
        }
    }

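    // Keyboard control for manual play: WASD maps to the four movement actions.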
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = k_NoAction;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = k_Right;
        }
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = k_Up;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = k_Left;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = k_Down;
        }
    }

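    // Resets the area and picks a new goal. A random goal is chosen when a goal
    // sensor is present; otherwise the goal defaults to the green plus.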
    public override void OnEpisodeBegin()
    {
        area.AreaReset();
        Array values = Enum.GetValues(typeof(GridGoal));
        if (m_GoalSensor is object)
        {
            CurrentGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
        }
        else
        {
            CurrentGoal = GridGoal.GreenPlus;
        }
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

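    // Forces the render camera to render before a potential decision so RenderTexture
    // observations are up to date. During training (communicator on) a decision is
    // requested every FixedUpdate; otherwise requests are throttled by
    // timeBetweenDecisionsAtInference.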
    void WaitTimeInference()
    {
        if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null)
        {
            renderCamera.Render();
        }

        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
        }
        else
        {
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}