|
using System.Collections; |
|
using UnityEngine; |
|
using Unity.MLAgents; |
|
using Unity.MLAgents.Actuators; |
|
using Unity.MLAgents.Sensors; |
|
|
|
/// <summary>
/// ML-Agents agent for a hallway scene containing two symbols (O and X) and two goals.
/// At the start of each episode one symbol is chosen at random and placed near the ground
/// while the other is moved far below it (y = -1000). Colliding with the goal whose tag
/// matches the chosen symbol gives a reward of +1; colliding with the other goal gives -0.1.
/// Either collision ends the episode.
/// </summary>
public class HallwayAgent : Agent
{
    // Tags compared against the collided object in OnCollisionEnter.
    const string k_SymbolOGoalTag = "symbol_O_Goal";
    const string k_SymbolXGoalTag = "symbol_X_Goal";

    // Keys used when reporting outcomes through the StatsRecorder.
    const string k_CorrectGoalStat = "Goal/Correct";
    const string k_WrongGoalStat = "Goal/Wrong";

    // Multiplied by Time.deltaTime to get the rotation angle passed to transform.Rotate.
    const float k_TurnSpeed = 150f;

    // Seconds the success/fail ground material stays applied before being restored.
    const float k_GroundFlashSeconds = 0.5f;

    // Y coordinate used for the symbol that is not selected this episode.
    const float k_HiddenSymbolHeight = -1000f;

    /// <summary>Ground object; its position is the origin for agent and symbol placement.</summary>
    public GameObject ground;

    /// <summary>Area object; its position is the origin for goal placement.</summary>
    public GameObject area;

    /// <summary>Goal object repositioned each episode alongside <see cref="symbolXGoal"/>.</summary>
    public GameObject symbolOGoal;

    /// <summary>Goal object repositioned each episode alongside <see cref="symbolOGoal"/>.</summary>
    public GameObject symbolXGoal;

    /// <summary>Symbol placed near the ground when the episode's selection is 0.</summary>
    public GameObject symbolO;

    /// <summary>Symbol placed near the ground when the episode's selection is 1.</summary>
    public GameObject symbolX;

    /// <summary>When true, the normalized step count is added as a vector observation.</summary>
    public bool useVectorObs;

    Rigidbody m_AgentRb;

    // Material read from the ground renderer during Initialize; restored after each flash.
    Material m_GroundMaterial;

    Renderer m_GroundRenderer;

    HallwaySettings m_HallwaySettings;

    // 0 when symbol O is the active symbol for this episode, 1 when symbol X is.
    int m_Selection;

    StatsRecorder m_statsRecorder;

    /// <summary>
    /// Caches the scene settings, the agent's rigidbody, the ground renderer with its
    /// current material, and the Academy's stats recorder.
    /// </summary>
    public override void Initialize()
    {
        m_HallwaySettings = FindObjectOfType<HallwaySettings>();
        m_AgentRb = GetComponent<Rigidbody>();
        m_GroundRenderer = ground.GetComponent<Renderer>();
        m_GroundMaterial = m_GroundRenderer.material;
        m_statsRecorder = Academy.Instance.StatsRecorder;
    }

    /// <summary>
    /// Adds a single observation, StepCount / MaxStep, when <see cref="useVectorObs"/> is set.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observation.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!useVectorObs)
        {
            return;
        }

        // MaxStep == 0 would make the division produce NaN/Infinity. Report 0 instead,
        // still emitting exactly one value so the observation size does not change.
        var normalizedStep = MaxStep > 0 ? StepCount / (float)MaxStep : 0f;
        sensor.AddObservation(normalizedStep);
    }

    /// <summary>
    /// Applies <paramref name="mat"/> to the ground renderer, waits, then restores the
    /// material cached in <see cref="Initialize"/>.
    /// </summary>
    /// <param name="mat">Material to show temporarily.</param>
    /// <param name="time">Seconds to wait before restoring the cached material.</param>
    IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        m_GroundRenderer.material = mat;
        yield return new WaitForSeconds(time);
        m_GroundRenderer.material = m_GroundMaterial;
    }

    /// <summary>
    /// Translates the first discrete action into movement:
    /// 1 = forward, 2 = backward, 3 = rotate about +up, 4 = rotate about -up.
    /// Any other value produces no movement and no rotation.
    /// </summary>
    /// <param name="act">Discrete action segment; only index 0 is read.</param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var moveDirection = Vector3.zero;
        var turnAxis = Vector3.zero;

        switch (act[0])
        {
            case 1:
                moveDirection = transform.forward;
                break;
            case 2:
                moveDirection = -transform.forward;
                break;
            case 3:
                turnAxis = transform.up;
                break;
            case 4:
                turnAxis = -transform.up;
                break;
        }

        transform.Rotate(turnAxis, Time.deltaTime * k_TurnSpeed);
        m_AgentRb.AddForce(moveDirection * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies a per-step penalty of -1 / MaxStep and then moves the agent.
    /// </summary>
    /// <param name="actionBuffers">Actions for this step; the discrete branch drives movement.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // With MaxStep == 0 the division would yield -Infinity, so the penalty is skipped.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Scores the episode when the agent collides with either goal, flashes the ground
    /// with the matching settings material, records the outcome, and ends the episode.
    /// Collisions with objects carrying any other tag are ignored.
    /// </summary>
    /// <param name="col">Collision data supplied by Unity.</param>
    void OnCollisionEnter(Collision col)
    {
        var hitOGoal = col.gameObject.CompareTag(k_SymbolOGoalTag);
        var hitXGoal = col.gameObject.CompareTag(k_SymbolXGoalTag);
        if (!hitOGoal && !hitXGoal)
        {
            return;
        }

        var reachedCorrectGoal = (m_Selection == 0 && hitOGoal) || (m_Selection == 1 && hitXGoal);
        if (reachedCorrectGoal)
        {
            SetReward(1f);
            StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, k_GroundFlashSeconds));
            m_statsRecorder.Add(k_CorrectGoalStat, 1, StatAggregationMethod.Sum);
        }
        else
        {
            SetReward(-0.1f);
            StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, k_GroundFlashSeconds));
            m_statsRecorder.Add(k_WrongGoalStat, 1, StatAggregationMethod.Sum);
        }

        EndEpisode();
    }

    /// <summary>
    /// Keyboard control: D = action 3, W = action 1, A = action 4, S = action 2.
    /// Keys are checked in that order and the first one held wins; when none is held,
    /// the action buffer is left as it was passed in.
    /// </summary>
    /// <param name="actionsOut">Buffer whose first discrete action is written.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
    }

    /// <summary>
    /// Resets the scene for a new episode: picks the active symbol, positions both symbols,
    /// places the agent with a random position and heading, clears its rigidbody velocities,
    /// randomly assigns the two goals to the two goal positions, and adds zero-valued
    /// entries for both goal statistics.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        var agentOffset = -15f;
        var blockOffset = 0f;
        var groundPosition = ground.transform.position;

        // NOTE: the order of Random calls below is kept exactly as in the original code,
        // so a seeded run consumes random values in the same sequence.
        m_Selection = Random.Range(0, 2);
        if (m_Selection == 0)
        {
            symbolO.transform.position =
                new Vector3(Random.Range(-3f, 3f), 2f, blockOffset + Random.Range(-5f, 5f))
                + groundPosition;
            symbolX.transform.position =
                new Vector3(0f, k_HiddenSymbolHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPosition;
        }
        else
        {
            symbolO.transform.position =
                new Vector3(0f, k_HiddenSymbolHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPosition;
            symbolX.transform.position =
                new Vector3(0f, 2f, blockOffset + Random.Range(-5f, 5f))
                + groundPosition;
        }

        transform.position =
            new Vector3(Random.Range(-3f, 3f), 1f, agentOffset + Random.Range(-5f, 5f))
            + groundPosition;
        transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);

        // Assign zero directly: multiplying by 0 would keep a NaN velocity as NaN.
        // Angular velocity is cleared too so no spin carries over from the last episode.
        m_AgentRb.velocity = Vector3.zero;
        m_AgentRb.angularVelocity = Vector3.zero;

        var areaPosition = area.transform.position;
        var positiveXGoalSlot = new Vector3(7f, 0.5f, 22.29f) + areaPosition;
        var negativeXGoalSlot = new Vector3(-7f, 0.5f, 22.29f) + areaPosition;

        var goalPos = Random.Range(0, 2);
        if (goalPos == 0)
        {
            symbolOGoal.transform.position = positiveXGoalSlot;
            symbolXGoal.transform.position = negativeXGoalSlot;
        }
        else
        {
            symbolXGoal.transform.position = positiveXGoalSlot;
            symbolOGoal.transform.position = negativeXGoalSlot;
        }

        m_statsRecorder.Add(k_CorrectGoalStat, 0, StatAggregationMethod.Sum);
        m_statsRecorder.Add(k_WrongGoalStat, 0, StatAggregationMethod.Sum);
    }
}
|
|