Search code examples
c#unity-game-enginemachine-learningartificial-intelligenceml-agent

ML Agents - Multiple agents break the training


I've been working on a self-balancing agent that tries to keep its waist at a certain height. Recently, I upgraded the "thighs" so they can rotate on three axes instead of the two they had previously. After doing this, and modifying the ML-Agents code to allow for child sensors, training no longer works with more than one agent/area, and I'm not sure why. To be clear, only one agent still works, and it acts more "explosive" than usual; when it is by itself, it is much calmer while trying to balance. Maybe I messed up something else in the process? If anyone has any ideas, I'm open to anything. Thank you!

Broken ML Agents Inspector of agent

Agent Script :

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;

/// <summary>
/// ML-Agents agent that drives a set of hinge-jointed body parts and is
/// rewarded for keeping its waist level, above a minimum height and close to
/// its starting position.
/// </summary>
/// <remarks>
/// Body parts are addressed by index throughout the class:
/// 0 waist, 1 buttR, 2 buttL, 3 thighR, 4 thighL, 5 legR, 6 legL,
/// 7 footR, 8 footL, 9 hipR, 10 hipL.
/// </remarks>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject floor;
    public GameObject sensor;
    public GameObject waist;
    public GameObject wFront;           //Used to check balance of waist.
    public GameObject wBack;           //Used to check balance of waist.
    public GameObject hipR;
    public GameObject hipL;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    /// <summary>Amount added to (or subtracted from) a joint's max limit per action step.</summary>
    public float bodyMoveSensitivity = 3f;

    // Per-agent state. These arrays used to be static, which made every
    // BalanceAgent in the scene share (and overwrite) the same body parts,
    // joints and start poses, so only the last agent to run Start() was
    // actually controlled. They are instance fields now; [NonSerialized]
    // keeps them out of Unity's serialization since they are rebuilt in Start().
    [NonSerialized] public GameObject[] bodyParts = new GameObject[11];
    [NonSerialized] public HingeJoint[] hingeParts = new HingeJoint[11];
    [NonSerialized] public JointLimits[] jntLimParts = new JointLimits[11];

    [NonSerialized] public Vector3[] posStart = new Vector3[11];
    [NonSerialized] public Vector3[] eulerStart = new Vector3[11];

    /// <summary>
    /// Builds the indexed body-part table, records each part's starting
    /// position and rotation, and caches the hinge joint of every part that has one.
    /// </summary>
    public void Start() {
        bodyParts = new GameObject[] { waist /*0*/, buttR /*1*/, buttL /*2*/, thighR /*3*/, thighL /*4*/, legR /*5*/, legL /*6*/, footR /*7*/, footL /*8*/, hipR /*9*/, hipL /*10*/};

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;

            HingeJoint hinge = bodyParts[i].GetComponent<HingeJoint>();
            if (hinge != null) {
                hingeParts[i] = hinge;
                hingeParts[i].limits = jntLimParts[i];
            }
        }
    }

    /// <summary>Runs the base initialization and looks up the parent <c>BalancingArea</c>.</summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    /// <summary>
    /// Starts a new episode: tilts the floor randomly, puts every body part
    /// back at its recorded start pose with zero velocity, gives the waist a
    /// random yaw, and resets the limits of joints 1-10.
    /// </summary>
    public override void AgentReset() {
        floor.transform.eulerAngles = new Vector3(Random.Range(-15, 15), 0, Random.Range(-15, 15));             //Floor rotation

        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];

            Rigidbody body = bodyParts[i].GetComponent<Rigidbody>();
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;
        }
        waist.transform.eulerAngles = new Vector3(0, Random.Range(0, 360), 0);

        // Joint 1 previously read and applied jntLimParts[2] by mistake;
        // every joint now uses its own entry.
        SetJointTarget(1, 0);       // buttR
        SetJointTarget(2, 0);       // buttL
        SetJointTarget(3, 15);      // thighR
        SetJointTarget(4, 15);      // thighL
        SetJointTarget(5, -15);     // legR
        SetJointTarget(6, -15);     // legL
        SetJointTarget(7, 15);      // footR
        SetJointTarget(8, 15);      // footL
        SetJointTarget(9, 0);       // hipR
        SetJointTarget(10, 0);      // hipL
    }

    /// <summary>
    /// Applies one step of actions. Actions 0-7 move joints 1-8, action 8
    /// rotates the waist around Y, and actions 9 and 10 move joints 9 and 10.
    /// Afterwards the balance, height and position rewards are added.
    /// </summary>
    /// <param name="vectorAction">
    /// One value per controllable part (indices 0-10 are read); each is cast
    /// to int where 2 means +<see cref="bodyMoveSensitivity"/>,
    /// 3 means -<see cref="bodyMoveSensitivity"/> and anything else means no movement.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        StepJoint(1, vectorAction[0], -60, 60);     // buttR
        StepJoint(2, vectorAction[1], -60, 60);     // buttL
        StepJoint(3, vectorAction[2], -80, 80);     // thighR
        StepJoint(4, vectorAction[3], -80, 80);     // thighL
        StepJoint(5, vectorAction[4], -80, 5);      // legR
        StepJoint(6, vectorAction[5], -80, 5);      // legL
        StepJoint(7, vectorAction[6], -50, 50);     // footR
        StepJoint(8, vectorAction[7], -50, 50);     // footL
        StepJoint(9, vectorAction[9], -45, 45);     // hipR
        // The left hip used to be driven by the right hip's action value.
        StepJoint(10, vectorAction[10], -45, 45);   // hipL

        // The waist used to be driven by the left foot's action value.
        float waistDir = ActionToDelta(vectorAction[8]);
        bodyParts[0].transform.Rotate(0, waistDir, 0);

        sensor.transform.eulerAngles = new Vector3(0, 0, 0);

        float frontY = wFront.transform.position.y;
        float backY = wBack.transform.position.y;
        float buttRY = buttR.transform.position.y;
        float buttLY = buttL.transform.position.y;

        bool waistTilted = frontY < backY - 1 || frontY > backY + 1
                        || buttRY < buttLY - 1 || buttRY > buttLY + 1;
        if (waistTilted) {                //Maintain waist rotation.
            AddReward(-.2f);
        }
        else {
            AddReward(.01f);
        }

        Vector3 waistPos = waist.transform.position;

        if (waistPos.y <= -3) {             //Maintain waist height.
            AddReward(-.2f);
            Done();
        }
        else {
            AddReward(.01f);
        }

        bool waistDrifted = waistPos.x > posStart[0].x + 2 || waistPos.x < posStart[0].x - 2
                         || waistPos.z > posStart[0].z + 2 || waistPos.z < posStart[0].z - 2;
        if (waistDrifted) {              //Maintain waist position.
            AddReward(-.2f);
        }
        else {
            AddReward(.01f);
        }
    }

    /// <summary>
    /// Adds, for every body part, its position, rotation, velocity, angular
    /// velocity and stored joint limits, followed by the height and rotation
    /// of the two waist markers.
    /// </summary>
    public override void CollectObservations() {

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform partTransform = bodyParts[i].transform;
            Rigidbody body = bodyParts[i].GetComponent<Rigidbody>();

            AddVectorObs(partTransform.position);
            AddVectorObs(partTransform.rotation);
            AddVectorObs(body.velocity);
            AddVectorObs(body.angularVelocity);
            AddVectorObs(jntLimParts[i].max);
            AddVectorObs(jntLimParts[i].min);
            // NOTE(review): the waist markers are added once per body part, so
            // they appear 11 times in the observation vector. Left as-is because
            // moving them out of the loop changes the observation size the
            // configured behavior expects; confirm before changing.
            AddVectorObs(wFront.transform.position.y);
            AddVectorObs(wFront.transform.rotation);
            AddVectorObs(wBack.transform.position.y);
            AddVectorObs(wBack.transform.rotation);

        }
    }

    /// <summary>Maps a discrete action value to a signed step: 2 gives +sensitivity, 3 gives -sensitivity, otherwise 0.</summary>
    private float ActionToDelta(float action) {
        switch ((int)action) {
            case 2:
                return bodyMoveSensitivity;
            case 3:
                return -bodyMoveSensitivity;
            default:
                return 0;
        }
    }

    /// <summary>Sets joint <paramref name="index"/> to a one-degree window ending at <paramref name="max"/> and applies it to the hinge.</summary>
    private void SetJointTarget(int index, float max) {
        jntLimParts[index].max = max;
        jntLimParts[index].min = jntLimParts[index].max - 1;
        hingeParts[index].limits = jntLimParts[index];
    }

    /// <summary>
    /// Moves joint <paramref name="index"/> by the step encoded in
    /// <paramref name="action"/> while its stored limits are strictly inside
    /// (<paramref name="lower"/>, <paramref name="upper"/>); otherwise pulls the
    /// stored limits back inside that range.
    /// </summary>
    private void StepJoint(int index, float action, float lower, float upper) {
        float delta = ActionToDelta(action);

        if (jntLimParts[index].max < upper && jntLimParts[index].min > lower) {
            jntLimParts[index].max += delta;
            jntLimParts[index].min = jntLimParts[index].max - 1;
            hingeParts[index].limits = jntLimParts[index];
        }
        else {
            // Out of range: only the stored limits are corrected here; the
            // hinge receives them on the next in-range step (same as before).
            if (jntLimParts[index].min <= lower) {
                jntLimParts[index].max = lower + 2;
            }
            else if (jntLimParts[index].max >= upper) {
                jntLimParts[index].max = upper - 1;
            }
            jntLimParts[index].min = jntLimParts[index].max - 1;
        }
    }
}

Solution

  • Unfortunately, I had to restart from scratch to resolve this issue. After doing that, and making sure I had a more recent version of ML-Agents, it all worked fine again.