1

I've been working on a pair of self-balancing legs in Unity ML-Agents. If his 'waist' drops below a certain y-position (i.e. he falls over or trips), the area is supposed to reset and points are supposed to be deducted from his reward score. I'm very new to machine learning, so go easy on me! Why is the agent not resetting when he falls over?

[Image: Legs trainer report] [Image: Agents in inspector]




Code to Agent (Updated):

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    //public GameObject goal;

    // Every body part, in the fixed order
    // waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL.
    // Rebuilt from the individual fields in InitializeAgent.
    public GameObject[] bodyParts = new GameObject[9];
    // World-space position / Euler angles of each entry in bodyParts, captured in
    // InitializeAgent and restored by AgentReset.
    public Vector3[] posStart = new Vector3[9];
    public Vector3[] eulerStart = new Vector3[9];

    // Waist height (world-space y) above which each step is rewarded.
    private const float UprightHeight = -1f;
    // Waist height (world-space y) at or below which the episode ends.
    private const float FallenHeight = -3f;

    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        bodyParts = new GameObject[]{waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL};

        // posStart/eulerStart are public, so Unity serialization can replace the
        // 9-element initializers with arrays of another length. Size them from
        // bodyParts so the snapshot loop below can never index out of range.
        posStart = new Vector3[bodyParts.Length];
        eulerStart = new Vector3[bodyParts.Length];

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;
        }
    }

    // Restores every body part to the pose captured in InitializeAgent and
    // clears its Rigidbody velocities so no momentum carries into the next episode.
    public override void AgentReset() {
        for (int i = 0; i < bodyParts.Length; i++) {
            Transform part = bodyParts[i].transform;
            part.position = posStart[i];
            part.eulerAngles = eulerStart[i];

            Rigidbody body = bodyParts[i].GetComponent<Rigidbody>();
            if (body != null) {
                body.velocity = Vector3.zero;
                body.angularVelocity = Vector3.zero;
            }
        }
    }

    // Applies one discrete action per joint (vectorAction[0..7]), then shapes the
    // reward from the waist height and ends the episode once the waist has fallen.
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        if (waist.transform.position.y > UprightHeight) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waist.transform.position.y <= FallenHeight) {
            print("reset!");
            AddReward(-.1f);
            Done();
        }
    }

    // Observation vector: waist yaw, eight joint angles, then the waist's world position.
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);
        AddVectorObs(waist.transform.position);
    }

    // Maps a discrete action value (truncated to int) to a 1-degree rotation step:
    // 1 -> -1, 2 -> +1, anything else (including 0 and 3) -> 0.
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    // Rotates the joint about its local y axis by the step selected by `action`.
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }
}




Code to Area:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

public class BalancingArea : Area
{
    // All BalanceAgent components under this area's hierarchy, collected in Awake.
    public List<BalanceAgent> BalanceAgent { get; private set; }
    // The BalanceAcademy found in the scene (null if FindObjectOfType finds none).
    public BalanceAcademy BalanceAcademy { get; private set; }
    // Object whose x/z position anchors agent resets; falls back to this area's
    // own transform when left unassigned in the Inspector.
    public GameObject area;

    private void Awake() {
        BalanceAgent = transform.GetComponentsInChildren<BalanceAgent>().ToList();   // Grabs all agents in area
        BalanceAcademy = FindObjectOfType<BalanceAcademy>();                         // Grabs the balance academy
    }

    // Moves the agent to y = 0 above the anchor's x/z position with no rotation.
    // The empty Start/Update callbacks were removed: Unity still invokes empty
    // magic methods every frame, which costs time for no effect.
    public void ResetAgentPosition(BalanceAgent agent) {
        // Use Unity's overloaded != (not ?.) so destroyed/unassigned objects read as null.
        Transform anchor = area != null ? area.transform : transform;
        agent.transform.position = new Vector3(anchor.position.x, 0f, anchor.position.z);
        agent.transform.rotation = Quaternion.identity;
    }
}




Code to BalanceAcademy:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Academy for the balancing environment. It declares no members of its own;
/// everything it does is inherited from <see cref="Academy"/>.
/// </summary>
public class BalanceAcademy : Academy { }



Command used to run trainer:

mlagents-learn config/trainer_config.yaml --run-id=balancetest09 --train
  • Please include the definition for `BalanceAcademy` – Ruzihm Dec 12 '19 at 17:11
  • @Ruzihm I've watched videos where they don't put much in the BalanceAcademy script. Actually, I'm not sure if I have it attached to any game object. Does it need to be? Should it be attached to the area object? Either way, it's now added to the post. Thanks. – Jaden Williams Dec 12 '19 at 18:05
  • Thanks for the update. Don't forget to accept an answer if [someone answers](https://stackoverflow.com/help/someone-answers) your question! – Ruzihm Dec 12 '19 at 18:31
  • Do all of the body parts have rigidbodies on them? – Ruzihm Dec 12 '19 at 20:32
  • @Ruzihm Yes, they do. Why would that matter? – Jaden Williams Dec 12 '19 at 20:35
  • Because rigidbodies contain velocity & angular velocity information that you would also want to reset. – Ruzihm Dec 12 '19 at 20:36
  • How did you pick `waist.transform.position.y <= -3` as your reset condition? What is the starting **world** position of the waist? **This may be different from what is shown in the inspector!** You should be able to see it in the Unity log by adding `Debug.Log(waist.transform.position);` at the end of `InitializeAgent` – Ruzihm Dec 12 '19 at 21:28
  • @Ruzihm So funny thing.. Even when I comment-out the Done() call, it still keeps resetting! I tried to also have something print right before AND after the method is called, and the print statements never showed up in the logs. – Jaden Williams Dec 12 '19 at 21:34
  • Bizarre. Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/204150/discussion-between-ruzihm-and-jaden-williams). – Ruzihm Dec 12 '19 at 21:36

1 Answer

1

From the documentation on creating a new environment:

Initialization and Resetting the Agent

When the Agent reaches its target, it marks itself done and its Agent reset function moves the target to a random location. In addition, if the Agent rolls off the platform, the reset function puts it back onto the floor.

To move the target GameObject, we need a reference to its Transform (which stores a GameObject's position, orientation and scale in the 3D world). To get this reference, add a public field of type Transform to the RollerAgent class. Public fields of a component in Unity get displayed in the Inspector window, allowing you to choose which GameObject to use as the target in the Unity Editor.

To reset the Agent's velocity (and later to apply force to move the agent) we need a reference to the Rigidbody component. A Rigidbody is Unity's primary element for physics simulation. (See Physics for full documentation of Unity physics.) Since the Rigidbody component is on the same GameObject as our Agent script, the best way to get this reference is using GameObject.GetComponent<T>(), which we can call in our script's Start() method.

So far, our RollerAgent script looks like:

using System.Collections.Generic;
using UnityEngine;
using MLAgents;

public class RollerAgent : Agent
{
    // Object the agent should reach; assigned in the Inspector.
    public Transform Target;

    // This agent's Rigidbody, looked up once in Start.
    Rigidbody body;

    void Start () {
        body = GetComponent<Rigidbody>();
    }

    // Puts a fallen agent back on the floor with no momentum, then relocates the target.
    public override void AgentReset()
    {
        bool fellOff = transform.position.y < 0;
        if (fellOff)
        {
            body.angularVelocity = Vector3.zero;
            body.velocity = Vector3.zero;
            transform.position = new Vector3( 0, 0.5f, 0);
        }

        // Random x, then random z, each mapped from Random.value into roughly [-4, 4].
        float targetX = Random.value * 8 - 4;
        float targetZ = Random.value * 8 - 4;
        Target.position = new Vector3(targetX, 0.5f, targetZ);
    }
}

So, you should override AgentReset method so that that will reset the position of the agent's joints. To get you started, you could take the rotation and position of each of the joints in InitializeAgent, and then restore them in AgentReset. Also, zero out the velocity and angular velocity of the rigidbodies.

I don't see anything in the documentation or examples about calling Done in Update, so it may be recommended or even required for it to be in AgentAction to behave as expected. Might as well move everything out of Update and into AgentAction.

Also, you may want to use transform.localEulerAngles in your feature vector, which has 3 components (xyz), instead of transform.localRotation, which has 4 components (xyzw). Otherwise, you should not omit the w component of localRotation.

Altogether, it might look like this:

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    public GameObject goal;

    // Waist height (world-space y) above which a step earns UprightReward.
    // Kept as double so `y > UprightHeight` matches the original `y > -1.3` exactly.
    private const double UprightHeight = -1.3;
    // Waist height (world-space y) at or below which the episode ends.
    private const float FallenHeight = -3f;
    private const float UprightReward = .1f;
    private const float SaggingPenalty = -.02f;
    private const float FallPenalty = -.1f;

    // Parallel lists built in InitializeAgent: index i in each refers to the same body part.
    private List<GameObject> gameObjectsToReset;
    private List<Rigidbody> rigidbodiesToReset;
    private List<Vector3> initEulers;
    private List<Vector3> initPositions;
    // Waist Rigidbody, cached so CollectObservations doesn't call GetComponent every step.
    private Rigidbody waistBody;

    // Records the starting world pose and Rigidbody of every body part.
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        gameObjectsToReset = new List<GameObject>(new GameObject[]{
                waist, buttR, buttL, thighR, thighL, legR, legL,
                footR, footL});
        rigidbodiesToReset = new List<Rigidbody>(gameObjectsToReset.Count);
        initEulers = new List<Vector3>(gameObjectsToReset.Count);
        initPositions = new List<Vector3>(gameObjectsToReset.Count);

        foreach (GameObject g in gameObjectsToReset) {
            rigidbodiesToReset.Add(g.GetComponent<Rigidbody>());
            initEulers.Add(g.transform.eulerAngles);
            initPositions.Add(g.transform.position);
        }

        waistBody = waist.GetComponent<Rigidbody>();
    }

    // Restores the recorded pose and zeroes velocities so momentum from the previous
    // episode doesn't make the next one start already falling.
    public override void AgentReset() {
        for (int i = 0 ; i < gameObjectsToReset.Count ; i++) {
            Transform t = gameObjectsToReset[i].transform;
            t.position = initPositions[i];
            t.eulerAngles = initEulers[i];

            // Parts without a Rigidbody are repositioned but have no velocity to clear.
            Rigidbody r = rigidbodiesToReset[i];
            if (r != null) {
                r.velocity = Vector3.zero;
                r.angularVelocity = Vector3.zero;
            }
        }
    }

    // Applies one discrete action per joint (vectorAction[0..7]), shapes the reward
    // from the waist height, and ends the episode once the waist has fallen.
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        // NOTE(review): these are absolute world heights, so they assume every
        // training area sits at the same world y — confirm for multi-area scenes.
        float waistY = waist.transform.position.y;
        AddReward(waistY > UprightHeight ? UprightReward : SaggingPenalty);

        if (waistY <= FallenHeight) {
            // Penalize first, then mark done so the penalty clearly belongs to this episode.
            AddReward(FallPenalty);
            Done();
        }
    }

    // Observation vector (13 floats): waist yaw, eight joint angles, waist freezeRotation
    // flag, waist world position (xyz).
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        // NOTE(review): the butt joints are rotated about y in AgentAction but observed
        // on x here — confirm this is intended.
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);

        // NOTE(review): freezeRotation is a Rigidbody setting, not changing state, so it
        // likely carries no information; kept so the observation size stays the same.
        AddVectorObs(waistBody.freezeRotation);

        AddVectorObs(waist.transform.position);
    }

    // Maps a discrete action value (truncated to int) to a 1-degree rotation step:
    // 1 -> -1, 2 -> +1, anything else (including 0 and 3) -> 0.
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    // Rotates the joint about its local y axis by the step selected by `action`.
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }
}

Finally, make sure you set your BalanceAgent's Max Step to something large enough to see if the agent will fail, maybe 500 or 1000 for starters.

<code>Max Step</code> is editable in the inspector

Ruzihm
  • 19,749
  • 5
  • 36
  • 48
  • 1
    So now he's resetting (thanks), but it seems like it's happening every half a second! The only time I want him to reset and for a new 'episode' to start is when his waist falls below -3 on his y position. I thought the Done() method is used for that. Thanks. – Jaden Williams Dec 12 '19 at 20:11
  • @JadenWilliams I added a part about resetting the physics. Note the new `List` and the new parts in `InitializeAgent` and `AgentReset` – Ruzihm Dec 12 '19 at 20:35
  • @JadenWilliams Did my suggestion not change anything? – Ruzihm Dec 12 '19 at 20:40
  • @JadenWilliams Just fixed a typo, `rigidbodysToReset` -> `rigidbodiesToReset`. I have no idea how to reproduce your scene so I can only guess that half a second is how long it takes for the waist to go below y=-3, and that may be because the downward velocity from the previous episode(s) aren't reset. – Ruzihm Dec 12 '19 at 20:45
  • As you've already said, `Done` resets the agent. Try moving the code out of `Update` and into `AgentAction` – Ruzihm Dec 12 '19 at 21:09
  • @JadenWilliams Rigidbodies were relevant because if the velocity isn't reset, the agent will fall faster with each episode. When we get this working, you can comment out the velocity reset to see what I mean. – Ruzihm Dec 12 '19 at 21:31