
Genetic Algorithm Fitness and Crossover Selection


I'm attempting to use a genetic algorithm to arrange a random list of seats for a given section in a classroom so that all the groups in that classroom sit together.

Here is my current attempt. I'm unclear about the best way to improve this algorithm, as these are very early days in my GA learning.

I'm using the Java library Jenetics, which has been pretty awesome for getting me started. Here is the Engine:

        final Engine<EnumGene<Seat>, Double> ENGINE = Engine
                .builder(SeatingFitness::fitness, encoding)
                .populationSize(200)
                .selector(new RouletteWheelSelector<>())
                .alterers(
                        new PartiallyMatchedCrossover<>(0.2),
                        new Mutator<>(0.01)
                )
                .optimize(Optimize.MINIMUM)
                .build();
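To see why PartiallyMatchedCrossover suits a permutation encoding, here is a simplified, hypothetical sketch of PMX (one child only, int permutations of 0..n-1, explicit cut points `from` inclusive and `to` exclusive) together with a swap mutation. Both always produce a valid permutation, i.e. every seat still appears exactly once. This is not Jenetics' code; Jenetics ships its own PartiallyMatchedCrossover and also a SwapMutator, which may be worth trying in place of the generic Mutator for a permutation codec (worth confirming against the Jenetics docs).

```java
import java.util.Arrays;
import java.util.Random;

/** Illustrative PMX and swap mutation on int permutations (not the Jenetics implementation). */
final class PermutationOps {

    /** Partially matched crossover producing a single child from two permutations of 0..n-1. */
    static int[] pmx(int[] parent1, int[] parent2, int from, int to) {
        int n = parent1.length;
        int[] child = new int[n];
        Arrays.fill(child, -1);                 // -1 marks an unfilled position
        boolean[] used = new boolean[n];        // values already placed in the child
        int[] posInP2 = new int[n];             // value -> its index in parent2
        for (int i = 0; i < n; i++) posInP2[parent2[i]] = i;

        // 1. Copy the segment [from, to) from parent1.
        for (int i = from; i < to; i++) {
            child[i] = parent1[i];
            used[parent1[i]] = true;
        }
        // 2. Place parent2's segment values that were displaced, following the mapping
        //    parent1[pos] -> position of that value in parent2 until we leave the segment.
        for (int i = from; i < to; i++) {
            int value = parent2[i];
            if (used[value]) continue;
            int pos = i;
            while (pos >= from && pos < to) {
                pos = posInP2[parent1[pos]];
            }
            child[pos] = value;
            used[value] = true;
        }
        // 3. Fill every remaining position straight from parent2.
        for (int i = 0; i < n; i++) {
            if (child[i] == -1) child[i] = parent2[i];
        }
        return child;
    }

    /** Swap mutation: each position is, with probability p, swapped with a random position. */
    static int[] swapMutate(int[] perm, double p, Random random) {
        int[] result = perm.clone();
        for (int i = 0; i < result.length; i++) {
            if (random.nextDouble() < p) {
                int j = random.nextInt(result.length);
                int tmp = result[i];
                result[i] = result[j];
                result[j] = tmp;
            }
        }
        return result;
    }

    /** True if arr contains each of 0..arr.length-1 exactly once. */
    static boolean isPermutation(int[] arr) {
        boolean[] seen = new boolean[arr.length];
        for (int v : arr) {
            if (v < 0 || v >= arr.length || seen[v]) return false;
            seen[v] = true;
        }
        return true;
    }
}
```

For example, `pmx` on parents `{0,1,2,3,4,5,6,7}` and `{3,7,5,1,6,0,2,4}` with cut points 3 and 6 gives `{1,7,0,3,4,5,2,6}`: the middle segment comes from the first parent and everything else is repaired from the second, so no seat is duplicated or lost.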

The encoding looks like this:

    static final ISeq<Seat> seats = ISeq.of(createSeats());
    static final Codec<ISeq<Seat>, EnumGene<Seat>> encoding = Codecs.ofPermutation(seats);

    public static List<Seat> createSeats() {
        return new ArrayList<>(Arrays.asList(
                new Seat("Group C", 1),
                new Seat("Group A", 2),
                new Seat("Group B", 3)
                // .....more seats here....
        ));
    }
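The Seat class itself isn't shown. From the snippets it needs at least a group name and `getGroup()`, and `Seat.EMPTY_SEAT` must be a String, because it is compared against the group keys. Here is a hypothetical minimal version (the `"Empty"` value and `getNumber()` are assumptions), plus a helper showing that a permutation-encoded individual is just the same Seat objects in a different order, where list index i is the i-th physical seat:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;

/** Hypothetical minimal Seat; the real class is not shown in the question. */
final class Seat {
    /** Hypothetical group name for unoccupied seats (compared against group keys). */
    static final String EMPTY_SEAT = "Empty";

    private final String group;
    private final int number;

    Seat(String group, int number) {
        this.group = Objects.requireNonNull(group);
        this.number = number;
    }

    String getGroup() { return group; }

    int getNumber() { return number; }

    // No equals/hashCode override: list lookups such as indexOf then match by identity,
    // which works because every permutation reuses the same Seat instances.

    @Override
    public String toString() { return group + "#" + number; }

    /** A random individual: the same seats, shuffled into a new order. */
    static List<Seat> randomSeating(List<Seat> seats, Random random) {
        List<Seat> copy = new ArrayList<>(seats);
        Collections.shuffle(copy, random);
        return copy;
    }
}
```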

My fitness function can certainly be improved, and since I'm not using any libraries for it, any suggestions would be great. It looks like this:

Essentially, I find the x and y coordinates of each seat in a group, calculate how far each seat is from the others in that group, and sum those values. The lower the value, the better, hence the Optimize.MINIMUM in the Engine.
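Written out as a formula, the intended score is roughly the following (a reconstruction from the code below, ignoring its `(int)` casts). Here G is the set of non-empty groups with more than two seats, (x_i, y_i) are the coordinates of seat i, m_g is the 10/500 row-penalty multiplier, and 8 is numberOfSeatsInRow:

$$
F = \sum_{g \in G} \sum_{i \in g} \sum_{j \in g} \left( \sqrt{(x_i - x_j)^2 + (y_i - y_j)^2} + m_g \, \lvert y_i - y_j \rvert \right),
\qquad
m_g = \begin{cases} 10, & \lvert g \rvert \le 8 \\ 500, & \lvert g \rvert > 8 \end{cases}
$$

Each unordered pair is counted twice and the i = j terms are zero, so this doubles the score without changing which seating scores lowest. Note also that, because of the `> 2` size check, groups of exactly two seats never contribute, so a pair can be split up without any penalty.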

    private static int numberOfSeatsInRow = 8;

    public static double fitness(final ISeq<Seat> seats) {

        AtomicDouble score = new AtomicDouble();
        // Group together all the seats that belong to a particular group.
        Map<String, List<Seat>> grouping = seats.stream().collect(groupingBy(Seat::getGroup));

        grouping.forEach((group, groupsSeats) -> {
            // Find the location in the overall list of the seats for the group.
            if (!group.equals(Seat.EMPTY_SEAT)) {
                List<Integer> indexOfSeatInOverallList = groupsSeats.stream().map(seats::indexOf).collect(Collectors.toList());
                if (indexOfSeatInOverallList.size() > 2) {
                    // For each seat in the group, add up its distance to every seat in the same group.

                    double totalCalculated = indexOfSeatInOverallList.stream().reduce(0, (subTotal, currentElement) -> {

                        int xReferenceCoordinate = calculateXCoordinate(currentElement);
                        int yReferenceCoordinate = calculateYCoordinate(currentElement);

                        double totalDistance = 0;
                        int multiplier = groupsSeats.size() <= numberOfSeatsInRow ? 10 : 500;
                        for (Integer integer : indexOfSeatInOverallList) {
                            int xSecondary = calculateXCoordinate(integer);
                            int ySecondary = calculateYCoordinate(integer);
                            if (ySecondary != yReferenceCoordinate) {
                                totalDistance += multiplier * Math.abs(yReferenceCoordinate - ySecondary);
                            }
                            totalDistance += calculateDistanceBetweenTwoPoints(xReferenceCoordinate, yReferenceCoordinate, xSecondary, ySecondary);
                        }

                        return (int) totalDistance;
                    });

                    score.getAndAdd(totalCalculated);
                }
            }

        });
        return score.get();
    }

    private static int calculateXCoordinate(int positionInList) {
        int xPosition = positionInList % numberOfSeatsInRow;
        if (xPosition == 0) {
            xPosition = numberOfSeatsInRow;
        }
        return xPosition;
    }

    private static int calculateYCoordinate(int positionInList) {
        int xPosition = positionInList % numberOfSeatsInRow;
        int yPosition = positionInList / numberOfSeatsInRow;
        if (xPosition == 0) {
            yPosition = yPosition - 1;
        }

        return yPosition + 1;
    }

    private static double calculateDistanceBetweenTwoPoints(int x1, int y1, int x2, int y2) {
        // https://dqydj.com/2d-distance-calculator/
        double xValue = Math.pow((x2 - x1), 2);
        double yValue = Math.pow((y2 - y1), 2);
        return Math.sqrt(xValue + yValue);
    }
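One thing worth checking: `seats::indexOf` returns a 0-based index, while calculateXCoordinate and calculateYCoordinate look like they expect 1-based positions. With 8 seats per row, index 0 maps to (8, 0), alone in row 0, and indices 1 to 8 form row 1. If index 0 is meant to be the first seat of the first row, the grid the fitness function sees is shifted by one seat, and a plain 0-based mapping avoids the special cases. A stand-alone sketch comparing both (class and method names are hypothetical):

```java
/** Compares the question's coordinate helpers with a plain 0-based mapping. */
final class SeatCoordinates {
    static final int SEATS_PER_ROW = 8;

    // Same logic as calculateXCoordinate in the question.
    static int legacyX(int position) {
        int x = position % SEATS_PER_ROW;
        return x == 0 ? SEATS_PER_ROW : x;
    }

    // Same logic as calculateYCoordinate in the question.
    static int legacyY(int position) {
        int x = position % SEATS_PER_ROW;
        int y = position / SEATS_PER_ROW;
        if (x == 0) {
            y--;
        }
        return y + 1;
    }

    // Plain mapping for a 0-based list index: column and row both start at 0.
    static int column(int index) {
        return index % SEATS_PER_ROW;
    }

    static int row(int index) {
        return index / SEATS_PER_ROW;
    }
}
```

For example, `legacyX(0)` and `legacyY(0)` give (8, 0), whereas `column(0)` and `row(0)` give (0, 0); index 8 is (8, 1) in the original helpers and (0, 1) in the 0-based version.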

See the results image below. As you can see, it's pretty good, although it takes about 3 minutes to run before it produces a proper result.

[Image: Results of 3 iterations]


Solution

  • I had a look at the fitness function. Some things you calculate on every fitness function call can be calculated once instead.

    private static final ISeq<Seat> SEATS = ISeq.of(
        new Seat("Group C", 1),
        new Seat("Group A", 2),
        new Seat("Group B", 3)
    );
    
    private static final Map<String, List<Seat>> SEAT_GROUPS = SEATS.stream()
        .collect(groupingBy(Seat::getGroup));
    

    The SEAT_GROUPS map depends only on the seats list, so it never changes and can be built once. Also, if I'm reading it correctly, the reduce function in your fitness function ignores the previously accumulated distance.

    double totalCalculated = indexOfSeatInOverallList.stream()
        .reduce(0, (subTotal, currentElement) -> {
                // ... calculate totalDistance for currentElement as before ...
                // subTotal is ignored in your code, but should be added to the result.
                return (int) totalDistance + subTotal;
            });
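To see what the ignored subTotal does in practice, here is a small stand-alone sketch (class and method names are hypothetical, and `contribution()` stands in for the per-seat distance loop). `Stream.reduce` feeds the previous result back in as the first argument, so a lambda that ignores it simply returns the last element's contribution. Mapping to an IntStream and calling `sum()` avoids the problem, and also sidesteps the `Stream.reduce` requirement that the accumulator be associative, which `(subTotal, e) -> subTotal + f(e)` strictly speaking is not:

```java
import java.util.List;

/** Shows why a reduce that ignores its running total keeps only the last element's contribution. */
final class ReduceDemo {
    // Hypothetical stand-in for the per-seat distance computation.
    static int contribution(int element) {
        return element * 10;
    }

    // Mirrors the question's lambda: subTotal is never used.
    static int ignoringSubTotal(List<Integer> elements) {
        return elements.stream().reduce(0, (subTotal, e) -> contribution(e));
    }

    // Map each element to its contribution, then sum.
    static int summed(List<Integer> elements) {
        return elements.stream().mapToInt(ReduceDemo::contribution).sum();
    }
}
```

For the list 1, 2, 3, `ignoringSubTotal` returns 30 (only the last element counts) while `summed` returns 60.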
    

    Your calculateDistanceBetweenTwoPoints method can be implemented with Math.hypot:

    double distance(final int x1, final int y1, final int x2, final int y2) {
        // sqrt(x^2 + y^2)
        return Math.hypot(x2 - x1, y2 - y1);
    }
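`Math.hypot(dx, dy)` returns sqrt(dx² + dy²) without overflow or underflow in the intermediate squares, so for small grid coordinates it should agree with the `Math.pow`/`Math.sqrt` version up to floating-point rounding. A quick stand-alone check (the class name is hypothetical):

```java
/** Compares Math.hypot with the explicit Math.sqrt/Math.pow distance formula. */
final class DistanceCheck {
    // The question's original formula.
    static double withPow(int x1, int y1, int x2, int y2) {
        return Math.sqrt(Math.pow(x2 - x1, 2) + Math.pow(y2 - y1, 2));
    }

    // The suggested replacement.
    static double withHypot(int x1, int y1, int x2, int y2) {
        return Math.hypot(x2 - x1, y2 - y1);
    }
}
```

For (1, 1) to (4, 5), a 3-4-5 triangle, both give a distance of 5.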
    

    My "cleaned" version looks like this:

    private static final int SEATS_PER_ROW = 8;
    
    private static final ISeq<Seat> SEATS = ISeq.of(
        new Seat("Group C", 1),
        new Seat("Group A", 2),
        new Seat("Group B", 3)
    );
    
    private static final Map<String, List<Seat>> SEAT_GROUPS = SEATS.stream()
        .collect(groupingBy(Seat::getGroup));
    
    public static double fitness(final ISeq<Seat> seats) {
        double score = 0;
    
        for (var entry : SEAT_GROUPS.entrySet()) {
            final var group = entry.getKey();
            final var groupsSeats = entry.getValue();
            final int multiplier = groupsSeats.size() <= SEATS_PER_ROW ? 10 : 500;
    
            if (!group.equals(Seat.EMPTY_SEAT)) {
                final int[] indexes = groupsSeats.stream()
                    .mapToInt(seats::indexOf)
                    .toArray();
    
                if (indexes.length > 2) {
                    final double dist = IntStream.of(indexes)
                        .reduce(0, (a, b) -> toDistance(multiplier, indexes, a, b));
    
                    score += dist;
                }
            }
        }
    
        return score;
    }
    
    private static int toDistance(
        final int multiplier,
        final int[] indexes,
        final int sum,
        final int index
    ) {
        final int x1 = toX(index);
        final int y = toY(index);
    
        int total = 0;
        for (int i : indexes) {
            final int x2 = toX(i);
            final int y2 = toY(i);
            if (y2 != y) {
                total += multiplier*Math.abs(y - y2);
            }
            total += distance(x1, y, x2, y2);
        }
    
        return sum + total;
    }
    
    private static double distance(final int x1, final int y1, final int x2, final int y2) {
        // sqrt(x^2 + y^2)
        return Math.hypot(x2 - x1, y2 - y1);
    }
    
    private static int toX(final int index) {
        int x = index%SEATS_PER_ROW;
        if (x == 0) {
            x = SEATS_PER_ROW;
        }
        return x;
    }
    
    private static int toY(final int index) {
        final int x = index%SEATS_PER_ROW;
        int y = index/SEATS_PER_ROW;
        if (x == 0) {
            y = y - 1;
        }
        return y + 1;
    }
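Since a run takes about 3 minutes, two further details in the cleaned version may be worth a look. First, `seats::indexOf` scans the list for every seat, so just locating the seats costs O(n²) per fitness call; one pass over the seating gives every group's positions in O(n). Second, in toDistance, `total += distance(...)` on an `int` silently truncates each distance, because compound assignment casts implicitly. The following is a sketch under stated assumptions, not a drop-in replacement: it uses a hypothetical minimal Seat type and EMPTY_SEAT value, plain 0-based coordinates (`index % SEATS_PER_ROW`, `index / SEATS_PER_ROW`) instead of toX/toY, and it also scores groups of two, which the `> 2` check skips.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Sketch of a single-pass seating fitness (lower is better); see the assumptions above. */
final class FastSeatingFitness {
    static final int SEATS_PER_ROW = 8;
    static final String EMPTY_SEAT = "Empty"; // hypothetical stand-in for Seat.EMPTY_SEAT

    static final class Seat {                  // hypothetical minimal Seat
        final String group;
        final int number;

        Seat(String group, int number) {
            this.group = group;
            this.number = number;
        }
    }

    static double fitness(Iterable<Seat> seating) {
        // Single pass: record each group's list indexes instead of one indexOf per seat.
        Map<String, List<Integer>> indexesByGroup = new HashMap<>();
        int index = 0;
        for (Seat seat : seating) {
            indexesByGroup.computeIfAbsent(seat.group, g -> new ArrayList<>()).add(index);
            index++;
        }

        double score = 0;
        for (Map.Entry<String, List<Integer>> entry : indexesByGroup.entrySet()) {
            List<Integer> indexes = entry.getValue();
            // Skip empty seats and single-seat groups; pairs are scored here.
            if (entry.getKey().equals(EMPTY_SEAT) || indexes.size() < 2) {
                continue;
            }
            int multiplier = indexes.size() <= SEATS_PER_ROW ? 10 : 500;
            for (int a : indexes) {
                for (int b : indexes) {
                    int dx = a % SEATS_PER_ROW - b % SEATS_PER_ROW;
                    int dy = a / SEATS_PER_ROW - b / SEATS_PER_ROW;
                    // Same two terms as before (row penalty + Euclidean distance), kept as double.
                    score += multiplier * Math.abs(dy) + Math.hypot(dx, dy);
                }
            }
        }
        return score;
    }
}
```

For example, three "A" seats side by side at indexes 0 to 2 score 8.0, while moving one of them to index 8 (start of the next row) raises the score to 44 + 2√2. Because the loop only needs iteration, it should also accept a Jenetics `ISeq<Seat>` (which is Iterable) once the hypothetical Seat is swapped for the real class; that part is untested here.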