Search code examples
c#unity-game-enginemachine-learninggame-physicsml-agent

ML-Agents agent not resetting?


I've been working on a pair of legs that balance themselves. If his 'waist' drops below a certain y-position (i.e. he falls over or trips), the area is supposed to reset and points are supposed to be deducted from his reward score. I'm very new to machine learning, so go easy on me! Why doesn't the agent reset when he falls over?

Legs trainer report; Agents in Inspector




Code to Agent (Updated):

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent that tries to keep a pair of legs upright. Each decision may
/// rotate eight joints by one degree about their local Y axis; the agent is
/// rewarded while its waist stays high and its episode ends when the waist
/// drops too low.
/// </summary>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    // Parts restored on reset; rebuilt in InitializeAgent with the waist first.
    public GameObject[] bodyParts = new GameObject[9];
    // World-space start position of each entry in bodyParts (captured in InitializeAgent).
    public Vector3[] posStart = new Vector3[9];
    // World-space start Euler angles of each entry in bodyParts (captured in InitializeAgent).
    public Vector3[] eulerStart = new Vector3[9];

    // Rigidbody of each entry in bodyParts (null when a part has none), cached so
    // AgentReset does not call GetComponent for every part on every reset.
    private Rigidbody[] bodyRigidbodies;

    /// <summary>
    /// Finds the enclosing area, gathers the body parts and records their
    /// starting pose so <see cref="AgentReset"/> can restore it.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        bodyParts = new GameObject[] { waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL };

        // Size the snapshot arrays from bodyParts so their lengths can never disagree.
        posStart = new Vector3[bodyParts.Length];
        eulerStart = new Vector3[bodyParts.Length];
        bodyRigidbodies = new Rigidbody[bodyParts.Length];

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform part = bodyParts[i].transform;
            posStart[i] = part.position;
            eulerStart[i] = part.eulerAngles;
            bodyRigidbodies[i] = bodyParts[i].GetComponent<Rigidbody>();
        }
    }

    /// <summary>
    /// Moves every body part back to its recorded start pose and clears its
    /// linear and angular velocity.
    /// </summary>
    public override void AgentReset() {
        for (int i = 0; i < bodyParts.Length; i++) {
            Transform part = bodyParts[i].transform;
            part.position = posStart[i];
            part.eulerAngles = eulerStart[i];

            Rigidbody body = bodyRigidbodies[i];
            if (body != null) {
                body.velocity = Vector3.zero;
                body.angularVelocity = Vector3.zero;
            }
        }
    }

    /// <summary>
    /// Applies one decision and scores the resulting pose.
    /// </summary>
    /// <param name="vectorAction">
    /// Eight discrete actions, one per joint, in the order buttR, buttL, thighR,
    /// thighL, legR, legL, footR, footL. See <see cref="ActionToDirection"/>.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        float waistHeight = waist.transform.position.y;

        // Small per-decision reward for staying up, small penalty for sagging.
        if (waistHeight > -1) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        // Fallen over: penalise and end the episode. ML-Agents should then call
        // AgentReset (assuming "Reset On Done" is enabled for this agent — confirm in the Inspector).
        if (waistHeight <= -3) {
            print("reset!");
            AddReward(-.1f);
            Done();
        }
    }

    /// <summary>
    /// Adds 12 floats: nine local joint angles followed by the waist's world
    /// position (x, y, z). This count should match the agent's configured
    /// vector observation size.
    /// </summary>
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);
        AddVectorObs(waist.transform.position);
    }

    // Maps a discrete action to a rotation step in degrees:
    // 1 -> -1, 2 -> +1, anything else (0, 3, ...) -> 0.
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    // Rotates a joint by the step chosen for it, about its local Y axis.
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }
}




Code to Area:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

/// <summary>
/// Training area for the balancing environment. Collects the agents placed
/// beneath it and can move an agent back to the area's centre.
/// </summary>
public class BalancingArea : Area
{
    /// <summary>All BalanceAgent components found under this area when it awoke.</summary>
    public List<BalanceAgent> BalanceAgent { get; private set; }

    /// <summary>The scene's BalanceAcademy (null if none exists).</summary>
    public BalanceAcademy BalanceAcademy { get; private set; }

    /// <summary>Object whose position marks the centre of the area.</summary>
    public GameObject area;

    private void Awake() {
        // Grab every agent in this area.
        BalanceAgent = transform.GetComponentsInChildren<BalanceAgent>().ToList();
        // Grab the balance academy.
        BalanceAcademy = FindObjectOfType<BalanceAcademy>();
    }

    // Empty Start()/Update() were removed: Unity still invokes empty message
    // methods every frame, which costs time for no effect.

    /// <summary>
    /// Places <paramref name="agent"/> at the area's centre on the y = 0 plane
    /// and clears its rotation.
    /// </summary>
    public void ResetAgentPosition(BalanceAgent agent) {
        Vector3 centre = area.transform.position;
        agent.transform.position = new Vector3(centre.x, 0, centre.z);
        agent.transform.eulerAngles = Vector3.zero;
    }
}




Code to BalanceAcademy:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Academy for the balancing environment. It declares no members of its own
/// and relies entirely on the ML-Agents <c>Academy</c> base class.
/// </summary>
public class BalanceAcademy : Academy { }



Command used to run trainer:

mlagents-learn config/trainer_config.yaml --run-id=balancetest09 --train

Solution

  • From the documentation on creating a new environment:

    Initialization and Resetting the Agent

    When the Agent reaches its target, it marks itself done and its Agent reset function moves the target to a random location. In addition, if the Agent rolls off the platform, the reset function puts it back onto the floor.

    To move the target GameObject, we need a reference to its Transform (which stores a GameObject's position, orientation and scale in the 3D world). To get this reference, add a public field of type Transform to the RollerAgent class. Public fields of a component in Unity get displayed in the Inspector window, allowing you to choose which GameObject to use as the target in the Unity Editor.

    To reset the Agent's velocity (and later to apply force to move the agent) we need a reference to the Rigidbody component. A Rigidbody is Unity's primary element for physics simulation. (See Physics for full documentation of Unity physics.) Since the Rigidbody component is on the same GameObject as our Agent script, the best way to get this reference is using GameObject.GetComponent<T>(), which we can call in our script's Start() method.

    So far, our RollerAgent script looks like:

    using System.Collections.Generic;
    using UnityEngine;
    using MLAgents;
    
    /// <summary>
    /// Example agent: on reset it recovers from a fall and then moves the
    /// target to a random spot on the floor.
    /// </summary>
    public class RollerAgent : Agent
    {
        private Rigidbody body;

        private void Start() {
            body = GetComponent<Rigidbody>();
        }

        public Transform Target;

        public override void AgentReset()
        {
            bool fellOff = this.transform.position.y < 0;
            if (fellOff)
            {
                RecoverFromFall();
            }

            Target.position = RandomTargetPosition();
        }

        // Stops all motion and puts the agent back on the floor.
        private void RecoverFromFall()
        {
            body.angularVelocity = Vector3.zero;
            body.velocity = Vector3.zero;
            this.transform.position = new Vector3(0, 0.5f, 0);
        }

        // Random point within a +/-4 square around the origin at height 0.5.
        private static Vector3 RandomTargetPosition()
        {
            float x = Random.value * 8 - 4;
            float z = Random.value * 8 - 4;
            return new Vector3(x, 0.5f, z);
        }
    }
    

    So, you should override the AgentReset method so that it resets the position of the agent's joints. To get started, record the rotation and position of each joint in InitializeAgent, and then restore them in AgentReset. Also, zero out the velocity and angular velocity of the rigidbodies.

    I don't see anything in the documentation or examples about calling Done from Update, so calling it from AgentAction is probably recommended, and may even be required, for it to behave as expected. You might as well move everything out of Update and into AgentAction.

    Also, you may want to use transform.localEulerAngles in your feature vector, which has 3 components (x, y, z), instead of transform.localRotation, which has 4 components (x, y, z, w). If you do use localRotation, you should not omit its w component.

    Altogether, it might look like this:

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;
    
    /// <summary>
    /// ML-Agents agent that tries to keep a pair of legs upright by rotating
    /// eight joints one degree at a time; rewarded while the waist stays high,
    /// episode ended when the waist drops too low.
    /// </summary>
    public class BalanceAgent : Agent
    {
        private BalancingArea area;
        public GameObject waist;
        public GameObject buttR;
        public GameObject buttL;
        public GameObject thighR;
        public GameObject thighL;
        public GameObject legR;
        public GameObject legL;
        public GameObject footR;
        public GameObject footL;
        public GameObject goal;

        // Parallel lists: entry i of each describes the same body part.
        private List<GameObject> gameObjectsToReset;
        private List<Rigidbody> rigidbodiesToReset;
        private List<Vector3> initEulers;
        private List<Vector3> initPositions;

        // Cached so CollectObservations does not call GetComponent every step.
        private Rigidbody waistBody;

        /// <summary>
        /// Records each body part's Rigidbody and starting world pose so
        /// <see cref="AgentReset"/> can restore it.
        /// </summary>
        public override void InitializeAgent() {
            base.InitializeAgent();
            area = GetComponentInParent<BalancingArea>();

            gameObjectsToReset = new List<GameObject> {
                waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL
            };
            rigidbodiesToReset = new List<Rigidbody>();
            initEulers = new List<Vector3>();
            initPositions = new List<Vector3>();

            foreach (GameObject g in gameObjectsToReset) {
                rigidbodiesToReset.Add(g.GetComponent<Rigidbody>());
                initEulers.Add(g.transform.eulerAngles);
                initPositions.Add(g.transform.position);
            }

            waistBody = waist.GetComponent<Rigidbody>();
        }

        /// <summary>
        /// Restores every body part to its recorded start pose and clears its
        /// linear and angular velocity.
        /// </summary>
        public override void AgentReset() {
            for (int i = 0; i < gameObjectsToReset.Count; i++) {
                Transform t = gameObjectsToReset[i].transform;
                t.position = initPositions[i];
                t.eulerAngles = initEulers[i];

                Rigidbody r = rigidbodiesToReset[i];
                if (r != null) {
                    r.velocity = Vector3.zero;
                    r.angularVelocity = Vector3.zero;
                }
            }
        }

        /// <summary>
        /// Applies one decision and scores the resulting pose.
        /// </summary>
        /// <param name="vectorAction">
        /// Eight discrete actions, one per joint, in the order buttR, buttL,
        /// thighR, thighL, legR, legL, footR, footL. See <see cref="ActionToDirection"/>.
        /// </param>
        public override void AgentAction(float[] vectorAction) {
            RotateJoint(buttR, vectorAction[0]);
            RotateJoint(buttL, vectorAction[1]);
            RotateJoint(thighR, vectorAction[2]);
            RotateJoint(thighL, vectorAction[3]);
            RotateJoint(legR, vectorAction[4]);
            RotateJoint(legL, vectorAction[5]);
            RotateJoint(footR, vectorAction[6]);
            RotateJoint(footL, vectorAction[7]);

            float waistHeight = waist.transform.position.y;

            // Small per-decision reward for staying up, small penalty for sagging.
            if (waistHeight > -1.3) {
                AddReward(.1f);
            }
            else {
                AddReward(-.02f);
            }

            // Fallen over: record the penalty, then end the episode.
            if (waistHeight <= -3) {
                AddReward(-.1f);
                Done();
            }
        }

        /// <summary>
        /// Adds 13 values: nine local joint angles, the waist Rigidbody's
        /// freezeRotation flag, and the waist's world position (x, y, z). This
        /// count should match the agent's configured vector observation size.
        /// </summary>
        public override void CollectObservations() {
            AddVectorObs(waist.transform.localEulerAngles.y);
            AddVectorObs(buttR.transform.localEulerAngles.x);
            AddVectorObs(buttL.transform.localEulerAngles.x);
            AddVectorObs(thighR.transform.localEulerAngles.y);
            AddVectorObs(thighL.transform.localEulerAngles.y);
            AddVectorObs(legR.transform.localEulerAngles.y);
            AddVectorObs(legL.transform.localEulerAngles.y);
            AddVectorObs(footR.transform.localEulerAngles.y);
            AddVectorObs(footL.transform.localEulerAngles.y);

            // NOTE(review): nothing in this class changes freezeRotation, so this is
            // presumably a constant observation. Removing it would change the
            // observation size, so it is kept here.
            AddVectorObs(waistBody.freezeRotation);

            AddVectorObs(waist.transform.position);
        }

        // Maps a discrete action to a rotation step in degrees:
        // 1 -> -1, 2 -> +1, anything else (0, 3, ...) -> 0.
        private static int ActionToDirection(float action) {
            switch ((int)action) {
                case 1:
                    return -1;
                case 2:
                    return 1;
                default:
                    return 0;
            }
        }

        // Rotates a joint by the step chosen for it, about its local Y axis.
        private static void RotateJoint(GameObject joint, float action) {
            joint.transform.Rotate(0, ActionToDirection(action), 0);
        }
    }
    

    Finally, make sure you set your BalanceAgent's Max Step to something large enough to see if the agent will fail, maybe 500 or 1000 for starters.

    `Max Step` is editable in the Inspector.