Tags: performance, rust, nested-loops, slowdown, magic-square

Why is my Rust code suddenly almost freezing for slightly larger values?


I am writing a fun program to find magic squares of squares. I have 9 nested for_each loops, one for each cell of the 3x3 square. The constant LIMIT sets how many square numbers are in SQUARE_NUMBERS.

Once LIMIT goes beyond about 100 the program starts to slow down slightly, and it keeps slowing down gradually all the way to 297. As soon as LIMIT is 298, however, it suddenly slows to a crawl: when I log inside the inner loops, the variable fifth is incremented only slowly. For LIMIT <= 297 the program takes at most 23 seconds on my laptop, but with LIMIT = 298 I have let it run for hours and it still hasn't finished.

What is going on here? I am wondering whether all the nested loops are preventing some optimization from taking place when LIMIT is 298 but not when it is 297.

main.rs

mod square;

use std::{collections::HashMap, process::exit, time::Instant};

use rayon::prelude::*;
use square::{numbers_are_unique, numbers_make_sum, sums_are_equal};

const fn generate_square_numbers<const COUNT: usize>() -> [usize; COUNT] {
    let mut numbers: [usize; COUNT] = [0usize; COUNT];

    let mut counter: usize = 0;
    while counter < COUNT {
        numbers[counter] = (counter + 1) * (counter + 1);
        counter += 1;
    }

    numbers
}

fn get_most_frequent_total(square_numbers: &[usize]) -> usize {
    let total_iterations: usize = square_numbers.len().pow(3);
    let mut current_iteration: usize = 0;

    let mut totals_and_counts: HashMap<usize, usize> = HashMap::new();

    for &first in square_numbers.iter() {
        for &second in square_numbers.iter() {
            for &third in square_numbers.iter() {
                let total: usize = first + second + third;

                *totals_and_counts.entry(total).or_insert(0) += 1;

                current_iteration += 1;
                let progress: f64 = (current_iteration as f64 / total_iterations as f64) * 100.0;

                // Guard against a zero divisor when there are fewer than 1000 iterations.
                if current_iteration % (total_iterations / 1000).max(1) == 0 {
                    println!("Progress: {:.1}%", progress);
                }
            }
        }
    }

    let mut totals: Vec<(&usize, &usize)> = totals_and_counts.iter().collect();

    totals.sort_by(|a: &(&usize, &usize), b: &(&usize, &usize)| a.1.cmp(b.1));

    totals.last().map(|(&total, _)| total).unwrap()
}

fn main() {
    let start_time: Instant = Instant::now();

    const LIMIT: usize = 297; // this runs just fine with any value <=297
    const SQUARE_NUMBERS: [usize; LIMIT] = generate_square_numbers();

    let most_frequent_total: usize = get_most_frequent_total(&SQUARE_NUMBERS);

    println!("The most frequent total is {most_frequent_total}.");

    SQUARE_NUMBERS.iter().for_each(|first: &usize| {
        SQUARE_NUMBERS.par_iter().for_each(|second: &usize| {
            SQUARE_NUMBERS.iter().for_each(|third: &usize| {
                SQUARE_NUMBERS.iter().for_each(|fourth: &usize| {
                    SQUARE_NUMBERS.iter().for_each(|fifth: &usize| {
                        SQUARE_NUMBERS.iter().for_each(|sixth: &usize| {
                            SQUARE_NUMBERS.iter().for_each(|seventh: &usize| {
                                SQUARE_NUMBERS.iter().for_each(|eighth: &usize| {
                                    SQUARE_NUMBERS.iter().for_each(|ninth: &usize| {
                                        if numbers_make_sum(
                                            *first,
                                            *second,
                                            *third,
                                            *fourth,
                                            *fifth,
                                            *sixth,
                                            *seventh,
                                            *eighth,
                                            *ninth,
                                            most_frequent_total,
                                        ) && sums_are_equal(
                                            *first, *second, *third, *fourth, *fifth, *sixth,
                                            *seventh, *eighth, *ninth,
                                        ) && numbers_are_unique(
                                            *first, *second, *third, *fourth, *fifth, *sixth,
                                            *seventh, *eighth, *ninth,
                                        ) {
                                            println!(
                                                "{:?}",
                                                [
                                                    first, second, third, fourth, fifth, sixth,
                                                    seventh, eighth, ninth
                                                ]
                                            );
                                            exit(0);
                                        }
                                    });
                                });
                            });
                        });
                    });
                });
            });
        });
        println!("{} / {}", (*first as f32).sqrt(), LIMIT);
    });

    let end_time: Instant = Instant::now();

    println!("Elapsed time: {:?}", end_time - start_time);
}

square.rs

#[inline(always)]
pub(crate) fn numbers_make_sum(
    first: usize,
    second: usize,
    third: usize,
    fourth: usize,
    fifth: usize,
    sixth: usize,
    seventh: usize,
    eighth: usize,
    ninth: usize,
    sum: usize,
) -> bool {
    if first + second + third != sum {
        return false;
    }
    if fourth + fifth + sixth != sum {
        return false;
    }
    if seventh + eighth + ninth != sum {
        return false;
    }

    if first + fourth + seventh != sum {
        return false;
    }
    if second + fifth + eighth != sum {
        return false;
    }
    if third + sixth + ninth != sum {
        return false;
    }

    if first + fifth + ninth != sum {
        return false;
    }
    if third + fifth + seventh != sum {
        return false;
    }

    true
}

#[inline(always)]
pub(crate) fn numbers_are_unique(
    first: usize,
    second: usize,
    third: usize,
    fourth: usize,
    fifth: usize,
    sixth: usize,
    seventh: usize,
    eighth: usize,
    ninth: usize,
) -> bool {
    first != second
        && first != third
        && first != fourth
        && first != fifth
        && first != sixth
        && first != seventh
        && first != eighth
        && first != ninth
        && second != third
        && second != fourth
        && second != fifth
        && second != sixth
        && second != seventh
        && second != eighth
        && second != ninth
        && third != fourth
        && third != fifth
        && third != sixth
        && third != seventh
        && third != eighth
        && third != ninth
        && fourth != fifth
        && fourth != sixth
        && fourth != seventh
        && fourth != eighth
        && fourth != ninth
        && fifth != sixth
        && fifth != seventh
        && fifth != eighth
        && fifth != ninth
        && sixth != seventh
        && sixth != eighth
        && sixth != ninth
        && seventh != eighth
        && seventh != ninth
        && eighth != ninth
}

#[inline(always)]
pub(crate) fn sums_are_equal(
    first: usize,
    second: usize,
    third: usize,
    fourth: usize,
    fifth: usize,
    sixth: usize,
    seventh: usize,
    eighth: usize,
    ninth: usize,
) -> bool {
    if (first + second + third) != (fourth + fifth + sixth)
        || (fourth + fifth + sixth) != (seventh + eighth + ninth)
    {
        return false;
    }

    if (seventh + eighth + ninth) != (first + fourth + seventh)
        || (first + fourth + seventh) != (second + fifth + eighth)
        || (second + fifth + eighth) != (third + sixth + ninth)
    {
        return false;
    }

    (third + sixth + ninth) == (first + fifth + ninth)
        && (first + fifth + ninth) == (seventh + fifth + third)
}

Solution

  • I have 9 nested for_each loops

    This is the cause of your problem. Nine nested loops is an absurd amount of work. To get this program functional you will need to flatten your algorithm's structure so that it does far less work overall.

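    O(n^9) is never going to complete for even modestly sized inputs. For LIMIT = 297 the innermost closure would nominally run 297^9 ≈ 1.8 × 10^22 times, and going to 298 adds only about 3% to that count. A 23-second run is therefore only possible because the optimizer is already skipping almost all of that work; at 298 it apparently stops doing so, and you pay the real O(n^9) cost.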

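    You will want to look into how you can cull the paths taken during the search: reject a partial square as soon as it can no longer work, or derive cells from the constraints instead of enumerating every combination.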
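A minimal sketch of that kind of culling, assuming the goal is the same as in the question (nine distinct entries taken from the first LIMIT squares, all eight lines summing to the same total). This is a swapped-in technique, not the only option: it relies on the standard fact that the centre of any 3x3 magic square equals one third of the magic sum, chooses only the two top-left cells, and derives the remaining six from the row, column and diagonal constraints, so each candidate sum costs O(n^2) instead of O(n^9). `find_magic_squares` is a hypothetical helper name, and `i64` is used instead of `usize` so that intermediate differences can go negative without underflowing.

```rust
use std::collections::HashSet;

/// Searches for 3x3 magic squares whose entries all come from `values` and whose
/// rows, columns and diagonals all add up to `sum`.
///
/// Cell layout:
///   a b c
///   d e f
///   g h i
///
/// Rather than trying every value in all nine cells (O(n^9)), it fixes the centre
/// at sum / 3, picks only `a` and `b`, and derives the other six cells: O(n^2) per sum.
fn find_magic_squares(values: &[i64], sum: i64) -> Vec<[i64; 9]> {
    let mut found = Vec::new();

    // The centre of a 3x3 magic square is always sum / 3, so other sums fail at once.
    if sum % 3 != 0 {
        return found;
    }
    let e = sum / 3;

    // O(1) membership test for the derived cells.
    let allowed: HashSet<i64> = values.iter().copied().collect();
    if !allowed.contains(&e) {
        return found;
    }

    for &a in values {
        for &b in values {
            let c = sum - a - b; // top row
            let i = sum - a - e; // main diagonal
            let h = sum - b - e; // middle column
            let g = sum - c - e; // anti-diagonal
            let d = sum - a - g; // left column
            let f = sum - d - e; // middle row

            // Prune: every derived cell must be one of the allowed values.
            if ![c, d, f, g, h, i].iter().all(|x| allowed.contains(x)) {
                continue;
            }
            // With e == sum / 3 these two lines always hold; checked as a sanity check.
            if g + h + i != sum || c + f + i != sum {
                continue;
            }

            let square = [a, b, c, d, e, f, g, h, i];
            // All nine entries must be distinct.
            let distinct: HashSet<i64> = square.iter().copied().collect();
            if distinct.len() == 9 {
                found.push(square);
            }
        }
    }
    found
}

fn main() {
    // Sanity check on ordinary numbers: 1..=9 with sum 15 yields the Lo Shu square
    // and its rotations/reflections (8 arrangements).
    let one_to_nine: Vec<i64> = (1..=9).collect();
    let lo_shu = find_magic_squares(&one_to_nine, 15);
    println!("1..=9, sum 15: {} squares", lo_shu.len());

    // The question's search space: the first LIMIT perfect squares.
    // The centre must itself be one of them, so the only possible sums are 3 * k^2.
    const LIMIT: i64 = 298;
    let squares: Vec<i64> = (1..=LIMIT).map(|n| n * n).collect();
    for &centre in &squares {
        for square in find_magic_squares(&squares, 3 * centre) {
            println!("{:?}", square);
        }
    }
}
```

Because the centre has to be a square k², trying every sum of the form 3·k² covers all possible magic sums in O(n^3) total work, instead of searching only the most frequent total in O(n^9).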