Search code examples
c#unity-game-enginetensorboardreinforcement-learningml-agent

Why does my AI model train but not improve? - ML Agents


I created a simple game in Unity where a ball should hit the targets without hitting the walls. I started training, and the results were very bad: the ball collects only one of the 4 targets. However, EndEpisode() is called when it collects the last target.

Screenshot of the scene and the ball's path throughout the training of 1,650,000 steps (if I'm not wrong, since I counted one generation for every 10,000 steps of training).

The ball doesn't even try to hit the second target. What is wrong with my code?

I've even tried using a RayPerceptionSensor3D, replacing the sphere with a cylinder so that it doesn't roll and disturb the RayPerceptionSensor3D. But that gives even worse results.

using System.Security.Cryptography;
using System.Data.SqlTypes;
using System.Security;
using System.Runtime.InteropServices;
using System.Net.Sockets;
using System.ComponentModel.Design.Serialization;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using MLAgents.Sensors;
using TMPro;

/// <summary>
/// ML-Agents agent that drives a rigidbody with a two-axis force. It earns
/// rewards for touching the three sub-targets (<c>Target1</c>..<c>Target3</c>)
/// and the final <c>Target</c>, and is penalised for touching anything tagged
/// <c>wall</c>. Touching the final target or a wall ends the episode.
/// </summary>
public class MazeRoller : Agent
{
    Rigidbody rBody;

    // Position the agent had when Start ran; restored at the start of every episode.
    Vector3 ballpos;

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
        ballpos = rBody.transform.position;
    }

    // On-screen counters: "text" shows the number of completed 10000-action
    // windows, "miss" the wall hits and "hit" the final-target hits in the
    // current window.
    public TextMeshPro text;
    public TextMeshPro miss;
    public TextMeshPro hit;

    // count: actions received in the current window; c: completed windows;
    // h / m: final-target hits / wall hits in the current window.
    int count = 0, c = 0, h = 0, m = 0;

    // Number of sub-targets collected in the current episode.
    int boxescollect = 0;

    public Transform Target;
    public Transform st1;
    public Transform st2;
    public Transform st3;

    /// <summary>
    /// Resets the agent's motion and position and makes all three sub-targets
    /// visible and collidable again.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        rBody.angularVelocity = Vector3.zero;
        rBody.velocity = Vector3.zero;
        rBody.transform.position = ballpos;
        boxescollect = 0;

        SetSubTargetActive(st1, true);
        SetSubTargetActive(st2, true);
        SetSubTargetActive(st3, true);
    }

    /// <summary>
    /// Assigns rewards based on what the agent collided with. Objects are
    /// matched by name first ("Target", "Target1".."Target3") and then by the
    /// "wall" tag.
    /// </summary>
    void OnCollisionEnter(Collision collision)
    {
        GameObject other = collision.gameObject;

        switch (other.name)
        {
            case "Target":
                // The original code called SetReward(penalty) and then
                // unconditionally SetReward(2.0f), which discarded the penalty
                // and paid +2 even when sub-targets were skipped. The two
                // outcomes are now mutually exclusive.
                if (AnySubTargetRemaining())
                {
                    SetReward(-3.0f + boxescollect);
                }
                else
                {
                    SetReward(2.0f);
                }

                h++;
                hit.SetText(h.ToString());
                EndEpisode();
                return;

            case "Target1":
                CollectSubTarget(st1, 0.2f);
                return;

            case "Target2":
                CollectSubTarget(st2, 0.4f);
                return;

            case "Target3":
                CollectSubTarget(st3, 0.6f);
                return;
        }

        if (other.CompareTag("wall"))
        {
            // The original code used SetReward(-1.0f) after AddReward(...),
            // which erased the missing-sub-target penalty. Both penalties are
            // now accumulated with AddReward.
            if (AnySubTargetRemaining())
            {
                AddReward(-3.0f + boxescollect);
            }

            AddReward(-1.0f);

            m++;
            miss.SetText(m.ToString());
            EndEpisode();
        }
    }

    /// <summary>
    /// Writes the observation vector. The order and number of entries are kept
    /// exactly as before; presumably the observation size configured for this
    /// behaviour must match, so do not add or remove entries without updating
    /// that setting.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        Vector3 self = this.transform.position;

        // Final target and agent positions.
        sensor.AddObservation(Target.position);
        sensor.AddObservation(self);

        // Progress: collected count, and collected count minus three.
        sensor.AddObservation(boxescollect);
        sensor.AddObservation(boxescollect - 3);

        // Sub-target positions.
        sensor.AddObservation(st1.position);
        sensor.AddObservation(st2.position);
        sensor.AddObservation(st3.position);

        // Distances from the agent to the final target and to each sub-target.
        sensor.AddObservation(Vector3.Distance(Target.position, self));
        sensor.AddObservation(Vector3.Distance(st1.position, self));
        sensor.AddObservation(Vector3.Distance(st2.position, self));
        sensor.AddObservation(Vector3.Distance(st3.position, self));

        // Agent velocity on the x and z axes.
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    // Multiplier applied to the action vector before it is used as a force.
    public float speed = 10;

    /// <summary>
    /// Applies the two action values as a force on the x and z axes and
    /// maintains the on-screen statistics, which roll over every 10000 actions.
    /// </summary>
    public override void OnActionReceived(float[] vectorAction)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        rBody.AddForce(controlSignal * speed);

        count++;

        if (count == 10000)
        {
            // Start a new statistics window: clear hit/miss, bump the window counter.
            count = 0;
            h = 0;
            m = 0;
            c++;
            miss.SetText(m.ToString());
            hit.SetText(h.ToString());
            text.SetText(c.ToString());
        }
    }

    /// <summary>
    /// Manual control: maps the "Horizontal" and "Vertical" input axes to the
    /// two action values.
    /// </summary>
    public override float[] Heuristic()
    {
        var action = new float[2];
        action[0] = Input.GetAxis("Horizontal");
        action[1] = Input.GetAxis("Vertical");
        return action;
    }

    // Shows/hides a sub-target and switches its collider on/off together.
    private static void SetSubTargetActive(Transform subTarget, bool active)
    {
        subTarget.GetComponent<Renderer>().enabled = active;
        subTarget.GetComponent<Collider>().enabled = active;
    }

    // True while at least one sub-target is still visible, i.e. not yet collected.
    private bool AnySubTargetRemaining()
    {
        return st1.GetComponent<Renderer>().enabled
            || st2.GetComponent<Renderer>().enabled
            || st3.GetComponent<Renderer>().enabled;
    }

    // Counts a sub-target as collected, adds its reward and removes it from play.
    private void CollectSubTarget(Transform subTarget, float reward)
    {
        boxescollect++;
        AddReward(reward);
        SetSubTargetActive(subTarget, false);
    }
}

Weird graph of the training (TensorBoard): this is what I get in TensorBoard after the training.


Solution

  • You are ending the episode when only one goal has been accomplished, not when the whole task is complete. That is why your chart looks messy: the episode ends too early, so the agent never understands its purpose.

    I think you could add some new rules: (1) if the agent backtracks over its own steps, it is punished; (2) the agent receives a punishment if it does not collect all 4 cubes before the end of the episode.

    The episode should end only if the agent completes the task of collecting all 4 cubes (reward) or if the agent has taken a set number of steps without achieving its goal (punishment).

    I hope that helps. Sorry for my bad English.

    ___edit 2:___

    It is very likely that your problem has characteristics similar to those described in this document (specifically page 28):

    https://repositorio.upct.es/bitstream/handle/10317/8094/tfg-san-est.pdf?sequence=1&isAllowed=y (It is in Spanish, sorry, but Google Translate will give you a fairly accurate translation.)

    The problem in the document is identical to yours: the agent had problems with the corners. When it reached a corner it returned to the starting point, and that happened only at the corners.

    Have you tried changing the scene? Maybe try it without the walls, to see whether the agent is really looking for all the targets, and then dig deeper into the problem.

    The graph is the least of your concerns; it is only a representation. You will not get a good graph if the agent is not accomplishing its task.