Created a simple game in unity where ball should hit the targets without hitting the walls. So, started training and the results was too bad. The ball is just collecting one of the 4 targets. But EndEpisode() happens when it collects the last target.
The ball doesn't even tries to hit the second target. What is wrong with my code?
I've even tried with RayPerceptionSensor3D replacing the sphere with a cylinder so that it doesn't roll over and disturb the rayperceptionSensor3d. But it gives even much worse results.
using System.Security.Cryptography;
using System.Data.SqlTypes;
using System.Security;
using System.Runtime.InteropServices;
using System.Net.Sockets;
using System.ComponentModel.Design.Serialization;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using MLAgents.Sensors;
using TMPro;
public class MazeRoller : Agent
{
Rigidbody rBody;
Vector3 ballpos;
void Start () {
rBody = GetComponent<Rigidbody>();
ballpos = rBody.transform.position;
}
public TextMeshPro text;
public TextMeshPro miss;
public TextMeshPro hit;
int count=0,c=0,h=0,m=0;
int boxescollect=0;
public Transform Target;
public Transform st1;
public Transform st2;
public Transform st3;
public override void OnEpisodeBegin()
{
rBody.angularVelocity = Vector3.zero;
rBody.velocity = Vector3.zero;
rBody.transform.position = ballpos;
boxescollect=0;
st1.GetComponent<Renderer> ().enabled = true;
st1.GetComponent<Collider> ().enabled = true;
st2.GetComponent<Renderer> ().enabled = true;
st2.GetComponent<Collider> ().enabled = true;
st3.GetComponent<Renderer> ().enabled = true;
st3.GetComponent<Collider> ().enabled = true;
}
void OnCollisionEnter(Collision collision)
{
if(collision.gameObject.name == "Target")
{
if(st1.GetComponent<Renderer> ().enabled==true || st2.GetComponent<Renderer> ().enabled==true || st3.GetComponent<Renderer> ().enabled==true)
{
SetReward(-3.0f+(float)(boxescollect));
}
SetReward(2.0f);
h++;
hit.SetText(h+"");
EndEpisode();
}
else if(collision.gameObject.name == "Target1")
{
boxescollect++;
AddReward(0.2f);
st1.GetComponent<Renderer> ().enabled = false;
st1.GetComponent<Collider> ().enabled = false;
}
else if(collision.gameObject.name == "Target2")
{
boxescollect++;
AddReward(0.4f);
st2.GetComponent<Renderer> ().enabled = false;
st2.GetComponent<Collider> ().enabled = false;
}
else if(collision.gameObject.name == "Target3")
{
boxescollect++;
AddReward(0.6f);
st3.GetComponent<Renderer> ().enabled = false;
st3.GetComponent<Collider> ().enabled = false;
}
//collision.gameObject.name == "wall1"||collision.gameObject.name == "wall2"||collision.gameObject.name == "wall3"||collision.gameObject.name == "wall4"||collision.gameObject.name == "wall5"||collision.gameObject.name == "wall6"||collision.gameObject.name == "wall7"
else if(collision.gameObject.tag == "wall")
{
if(st1.GetComponent<Renderer> ().enabled==true || st2.GetComponent<Renderer> ().enabled==true || st3.GetComponent<Renderer> ().enabled==true)
{
AddReward(-3.0f+(float)(boxescollect));
}
SetReward(-1.0f);
m++;
miss.SetText(m+"");
EndEpisode();
}
}
public override void CollectObservations(VectorSensor sensor)
{
// Target and Agent positions
sensor.AddObservation(Target.position);
sensor.AddObservation(this.transform.position);
sensor.AddObservation(boxescollect);
sensor.AddObservation(boxescollect-3);
sensor.AddObservation(st1.position);
sensor.AddObservation(st2.position);
sensor.AddObservation(st3.position);
float dist = Vector3.Distance(Target.position,this.transform.position);
//Distance between Agent and target
sensor.AddObservation(dist);
float d1 = Vector3.Distance(st1.position,this.transform.position);
//Distance between Agent and target
sensor.AddObservation(d1);
float d2 = Vector3.Distance(st2.position,this.transform.position);
//Distance between Agent and target
sensor.AddObservation(d2);
float d3 = Vector3.Distance(st3.position,this.transform.position);
//Distance between Agent and target
sensor.AddObservation(d3);
// Agent velocity
sensor.AddObservation(rBody.velocity.x);
sensor.AddObservation(rBody.velocity.z);
}
public float speed = 10;
public override void OnActionReceived(float[] vectorAction)
{
Vector3 controlSignal = Vector3.zero;
controlSignal.x = vectorAction[0];
controlSignal.z = vectorAction[1];
//speed = vectorAction[2];
rBody.AddForce(controlSignal * speed);
//speed=0;
count++;
if(count==10000)
{
count=0;
h=0;
m=0;
c++;
miss.SetText(m+"");
hit.SetText(h+"");
text.SetText(c+"");
}
}
public override float[] Heuristic()
{
var action = new float[2];
action[0] = Input.GetAxis("Horizontal");
action[1] = Input.GetAxis("Vertical");
return action;
}
}
weird Graph of the training - tensorboard This is what I get after the training in tensorboard.