|
using System; |
|
using UnityEngine; |
|
using Unity.MLAgents; |
|
|
|
namespace Unity.MLAgentsExamples |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Phases of the step-by-step ("animated") board simulation driven by
/// Match3Agent.AnimatedUpdate. Transitions, as implemented there:
///   FindMatches  -> ClearMatched (matches found) or WaitForMove (none found)
///   ClearMatched -> Drop
///   Drop         -> FillEmpty
///   FillEmpty    -> FindMatches
///   WaitForMove  -> FindMatches
/// </summary>
internal enum State
{
    /// <summary>
    /// Sentinel value. It is not handled by the state machine, so reaching it
    /// results in an ArgumentOutOfRangeException.
    /// </summary>
    Invalid = -1,

    /// <summary>
    /// Mark matched cells on the board. Next state is ClearMatched when
    /// something matched, otherwise WaitForMove.
    /// </summary>
    FindMatches = 0,

    /// <summary>
    /// Clear the marked cells and reward the agent for the points earned.
    /// Next state is Drop.
    /// </summary>
    ClearMatched = 1,

    /// <summary>
    /// Drop the remaining cells on the board. Next state is FillEmpty.
    /// </summary>
    Drop = 2,

    /// <summary>
    /// Fill the board from above. Next state is FindMatches.
    /// </summary>
    FillEmpty = 3,

    /// <summary>
    /// Make sure the board has a valid move, then request a decision from
    /// the agent. Next state is FindMatches.
    /// </summary>
    WaitForMove = 4,
}
|
|
|
/// <summary>
/// Agent that plays on the <see cref="Match3Board"/> attached to the same GameObject.
/// Each FixedUpdate it advances the board either in one shot (<see cref="FastUpdate"/>)
/// or one phase at a time (<see cref="AnimatedUpdate"/>), rewards the agent for cleared
/// cells, and interrupts the episode once <see cref="MaxMoves"/> moves have been counted.
/// </summary>
public class Match3Agent : Agent
{
    /// <summary>
    /// Board the agent plays on. Assigned in <see cref="Awake"/> via GetComponent,
    /// so any value set beforehand is overwritten.
    /// </summary>
    [HideInInspector]
    public Match3Board Board;

    /// <summary>
    /// Delay between two phases of the animated simulation, measured against
    /// Time.deltaTime (seconds). Not used by the fast path.
    /// </summary>
    public float MoveTime = 1.0f;

    /// <summary>
    /// Number of counted moves after which the episode is interrupted.
    /// </summary>
    public int MaxMoves = 500;

    // Current phase of the animated state machine; reset to FindMatches on episode begin.
    private State m_CurrentState = State.WaitForMove;

    // Time left before the animated state machine advances to its next phase.
    private float m_TimeUntilMove;

    // Moves counted in the current episode; compared against MaxMoves in FixedUpdate.
    private int m_MovesMade;

    // Optional component (may be null); when it reports overrides the fast path is used.
    private ModelOverrider m_ModelOverrider;

    // Scale applied to the points returned by Board.ClearMatchedCells() before rewarding.
    private const float k_RewardMultiplier = 0.01f;

    /// <summary>
    /// Caches the <see cref="Match3Board"/> and <see cref="ModelOverrider"/> components
    /// found on this GameObject.
    /// </summary>
    protected override void Awake()
    {
        base.Awake();
        Board = GetComponent<Match3Board>();
        m_ModelOverrider = GetComponent<ModelOverrider>();
    }

    /// <summary>
    /// Resets the board and the per-episode bookkeeping (state, timer, move counter).
    /// </summary>
    public override void OnEpisodeBegin()
    {
        base.OnEpisodeBegin();

        Board.UpdateCurrentBoardSize();
        Board.InitSettled();

        m_CurrentState = State.FindMatches;
        m_TimeUntilMove = MoveTime;
        m_MovesMade = 0;
    }

    /// <summary>
    /// Advances the simulation once per physics step and interrupts the episode
    /// when the move budget is exhausted.
    /// </summary>
    private void FixedUpdate()
    {
        // Resolve the board in a single step when a communicator is connected or
        // when a ModelOverrider with overrides is present; otherwise step visibly.
        var useFast = Academy.Instance.IsCommunicatorOn
            || (m_ModelOverrider != null && m_ModelOverrider.HasOverrides);

        if (useFast)
        {
            FastUpdate();
        }
        else
        {
            AnimatedUpdate();
        }

        // The episode length is tracked in moves (m_MovesMade) rather than in steps,
        // so the episode is ended manually here.
        if (m_MovesMade >= MaxMoves)
        {
            EpisodeInterrupted();
        }
    }

    /// <summary>
    /// Resolves all pending matches immediately (mark, clear, drop, fill - repeated until
    /// nothing matches), then requests a decision and counts one move.
    /// </summary>
    private void FastUpdate()
    {
        while (Board.MarkMatchedCells())
        {
            var pointsEarned = Board.ClearMatchedCells();
            AddReward(k_RewardMultiplier * pointsEarned);
            Board.DropCells();
            Board.FillFromAbove();
        }

        EnsureBoardHasValidMoves();

        RequestDecision();
        m_MovesMade++;
    }

    /// <summary>
    /// Advances the state machine by one phase every <see cref="MoveTime"/>.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// The current state is not one of the phases handled by the state machine
    /// (for example <see cref="State.Invalid"/>).
    /// </exception>
    private void AnimatedUpdate()
    {
        m_TimeUntilMove -= Time.deltaTime;
        if (m_TimeUntilMove > 0.0f)
        {
            return;
        }

        m_TimeUntilMove = MoveTime;

        State nextState;
        switch (m_CurrentState)
        {
            case State.FindMatches:
                var hasMatched = Board.MarkMatchedCells();
                nextState = hasMatched ? State.ClearMatched : State.WaitForMove;
                if (nextState == State.WaitForMove)
                {
                    // The move is counted when the board has settled, i.e. just
                    // before the next decision is requested.
                    m_MovesMade++;
                }
                break;

            case State.ClearMatched:
                var pointsEarned = Board.ClearMatchedCells();
                AddReward(k_RewardMultiplier * pointsEarned);
                nextState = State.Drop;
                break;

            case State.Drop:
                Board.DropCells();
                nextState = State.FillEmpty;
                break;

            case State.FillEmpty:
                Board.FillFromAbove();
                nextState = State.FindMatches;
                break;

            case State.WaitForMove:
                EnsureBoardHasValidMoves();
                RequestDecision();
                nextState = State.FindMatches;
                break;

            default:
                // Name the field and the offending value so the failure is diagnosable.
                throw new ArgumentOutOfRangeException(
                    nameof(m_CurrentState),
                    m_CurrentState,
                    "Unhandled Match3Agent state.");
        }

        m_CurrentState = nextState;
    }

    /// <summary>
    /// Re-initializes the board until it offers at least one valid move.
    /// </summary>
    private void EnsureBoardHasValidMoves()
    {
        // NOTE(review): this loop is unbounded; it relies on Board.InitSettled()
        // eventually producing a board with a valid move - confirm for all board sizes.
        while (!HasValidMoves())
        {
            Board.InitSettled();
        }
    }

    /// <summary>
    /// Tells whether the board currently offers at least one valid move.
    /// </summary>
    /// <returns>
    /// true if <c>Board.ValidMoves()</c> yields at least one element; otherwise false.
    /// </returns>
    private bool HasValidMoves()
    {
        // Only the existence of a first element matters, so stop at the first one.
        foreach (var unused in Board.ValidMoves())
        {
            return true;
        }

        return false;
    }
}
|
|
|
} |
|
|