Tags: python, python-3.x, pytorch, genetic-algorithm, pygad

Why does PyGad fitness_function not work when inside of a class?


I am trying to train a neural network with a genetic algorithm using PyGAD, but it fails when the code lives inside a class. I have two otherwise equivalent pieces of code, and the one that defines the fitness function inside a class fails with this error:

raise ValueError("The fitness function must accept 2 parameters:\n1) A solution to calculate its fitness value.\n2) The solution's index within the population.\n\nThe passed fitness function named '{funcname}' accepts {argcount} parameter(s).".format(funcname=fitness_func.__code__.co_name, argcount=fitness_func.__code__.co_argcount))
ValueError: The fitness function must accept 2 parameters:
1) A solution to calculate its fitness value.
2) The solution's index within the population.

The passed fitness function named 'fitness_func' accepts 3 parameter(s).

Here is the simplified version of the one that doesn't work.

import torch
import torch.nn as nn
import pygad.torchga
import pygad

class NN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        return x

class Coin:
    def __init__(self):
        self.NeuralNet = NN(1440, 1440, 3)

    def fitness_func(self, solution, solution_idx):
        return 0

    def trainModel(self):

        torch_ga = pygad.torchga.TorchGA(model=self.NeuralNet, num_solutions=10)

        ga_instance = pygad.GA(num_generations=10,
                               num_parents_mating=2,
                               initial_population=torch_ga.population_weights,
                               fitness_func=self.fitness_func)

        ga_instance.run()

if __name__ == "__main__":
    coin = Coin()
    coin.trainModel()

Here is the simplified version of the one that does work.

import torch
import torch.nn as nn
import pygad.torchga
import pygad

class NN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        return x
    
def fitness_func(solution, solution_idx):
    return 0

def trainModel():

    NeuralNet = NN(1440, 1440, 3)

    torch_ga = pygad.torchga.TorchGA(model=NeuralNet, num_solutions=10)

    ga_instance = pygad.GA(num_generations=10,
                           num_parents_mating=2,
                           initial_population=torch_ga.population_weights,
                           fitness_func=fitness_func)

    ga_instance.run()

if __name__ == "__main__":
    trainModel()

Both versions should behave the same, but they don't. Why?
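The difference comes from how Python counts parameters on a function's code object. The sketch below is plain Python (no PyGAD needed); the class and function names are illustrative stand-ins for the ones in the question, and it prints what a check based on __code__.co_argcount sees in each case:

```python
import inspect


def fitness_func(solution, solution_idx):
    return 0


class Coin:
    def fitness_func(self, solution, solution_idx):
        return 0

    @staticmethod
    def static_fitness_func(solution, solution_idx):
        return 0


coin = Coin()

# A plain function: co_argcount counts its two positional parameters.
print(fitness_func.__code__.co_argcount)  # 2

# A bound method forwards attribute lookups such as __code__ to the
# underlying function, whose parameter list still includes `self`.
print(coin.fitness_func.__code__.co_argcount)  # 3

# inspect.signature() omits the already-bound `self`, so it reports 2.
print(len(inspect.signature(coin.fitness_func).parameters))  # 2

# A staticmethod has no `self` parameter at all.
print(coin.static_fitness_func.__code__.co_argcount)  # 2
```

So both versions pass a function with the same useful parameters, but the bound method's code object still declares three.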


Solution

  • If you look at the PyGAD source, you can see it explicitly checks that the fitness function has exactly two parameters:

            # Check if the fitness function accepts 2 paramaters.
            if (fitness_func.__code__.co_argcount == 2):
                self.fitness_func = fitness_func
            else:
                self.valid_parameters = False
                raise ValueError("The fitness function must accept 2 parameters:\n1) A solution to calculate its fitness value.\n2) The solution's index within the population.\n\nThe passed fitness function named '{funcname}' accepts {argcount} parameter(s).".format(funcname=fitness_func.__code__.co_name, argcount=fitness_func.__code__.co_argcount))
    

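    A function defined in a class body takes self as its first parameter, and a bound method such as self.fitness_func exposes that underlying function's __code__, so co_argcount is 3 and the check fails. If you want to keep the fitness function in a class, make it a static method; it then has no self parameter and co_argcount is 2: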

    @staticmethod
    def fitness_func(solution, solution_idx):
        return 0
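A static method cannot read instance state through self. If the fitness function needs it (for example, to use self.NeuralNet), one alternative is to hand PyGAD a small two-parameter wrapper that closes over the instance. This is a sketch that assumes the two-parameter co_argcount check quoted above; check_fitness_func is a hypothetical stand-in for that check so the example runs without PyGAD, and make_fitness_func and target are illustrative names, not part of the original code.

```python
class Coin:
    def __init__(self, target):
        # Hypothetical instance state the fitness function wants to use.
        self.target = target

    def fitness_func(self, solution, solution_idx):
        # Illustrative fitness: closer to zero when sum(solution) is near target.
        return -abs(sum(solution) - self.target)

    def make_fitness_func(self):
        # This lambda declares exactly two parameters (co_argcount == 2),
        # yet it can still reach `self` through the enclosing scope.
        return lambda solution, solution_idx: self.fitness_func(solution, solution_idx)


def check_fitness_func(fitness_func):
    """Hypothetical stand-in for the argument-count check quoted above."""
    if fitness_func.__code__.co_argcount != 2:
        raise ValueError(
            "The fitness function must accept 2 parameters; "
            f"'{fitness_func.__code__.co_name}' accepts "
            f"{fitness_func.__code__.co_argcount}."
        )
    return fitness_func


coin = Coin(target=5.0)

wrapped = check_fitness_func(coin.make_fitness_func())  # passes the check
print(wrapped([1.0, 2.0, 3.0], 0))  # -1.0

try:
    check_fitness_func(coin.fitness_func)  # bound method: 3 parameters
except ValueError as err:
    print(err)
```

Under the same assumption, trainModel would pass fitness_func=self.make_fitness_func() (or an equivalent inline lambda or nested def) instead of fitness_func=self.fitness_func. functools.partial is not a substitute here, because partial objects have no __code__ attribute for the check to read.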