Recently I got into ml-agent, however I encounter a problem, that is my agent object(blue) and target object(red) will overlap and break the simulation. I've tried to search for solutions but most of them suggest just avoid it manually(by coordinate) or object range. However, I want to train the model to be able to deal with random target rotation and target that spawned right beside the agent, which the two solution above will prevent the scene from generating.
(The rolling ball example for starter that was provided by ml-agent git doesn't have overlap prevention either, but I successfully trained the model)
Down below is my unity scene and code:
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
public class forklift : Agent
{
Rigidbody rBody;
void Start () {
rBody = GetComponent<Rigidbody>();
}
public Transform Target;
public override void OnEpisodeBegin()
{
// If the Agent fell, zero its momentum
if (this.transform.localPosition.y < 0)
{
this.rBody.angularVelocity = Vector3.zero;
this.rBody.velocity = Vector3.zero;
this.transform.localPosition = new Vector3( 0, 0f, 0);
}
// Move the target to a new spot (between(17,17)&(-17,-17),leaving width 8 free space)
Target.localPosition = new Vector3(Random.value * 34 - 17,
0f,
Random.value * 34 - 17);
//Target.transform.localRotation = Quaternion.Euler(0.0f, Random.value *360, 0.0f);
}
public override void CollectObservations(VectorSensor sensor)
{
// Target and Agent positions
sensor.AddObservation(Target.localPosition);
sensor.AddObservation(this.transform.localPosition);
// Agent velocity
sensor.AddObservation(rBody.velocity.x);
sensor.AddObservation(rBody.velocity.z);
}
public float forceMultiplier = 10;
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// Actions, size = 2
Vector3 controlSignal = Vector3.zero;
controlSignal.x = actionBuffers.ContinuousActions[0];
controlSignal.z = actionBuffers.ContinuousActions[1];
rBody.AddForce(controlSignal * forceMultiplier);
// Rewards
float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);
// Reached target
if (distanceToTarget < 1.42f)
{
SetReward(1.0f);
EndEpisode();
}
// Fell off platform
else if (this.transform.localPosition.y < 0)
{
EndEpisode();
}
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetAxis("Vertical");
}
}