
Growing a numpy array


I have a ruleset of 3x256 rules (256 rules for each of 3 channels). Each rule maps to a 3x3 grid of values, and each of those values is itself a rule, so the grid can be expanded again on the next iteration.

Example rules:

0 -> [[0,0,0],[0,1,0],[0,0,0]]  
1 -> [[1,1,1],[0,0,0],[1,1,1]]

Seed:

[[0]]

After 1 iteration:

[[0,0,0],
[0,1,0],
[0,0,0]]

After 2 iterations:

[[0,0,0,0,0,0,0,0,0],
[0,1,0,0,1,0,0,1,0],
[0,0,0,0,0,0,0,0,0],
[0,0,0,1,1,1,0,0,0],
[0,1,0,0,0,0,0,1,0],
[0,0,0,1,1,1,0,0,0],
[0,0,0,0,0,0,0,0,0],
[0,1,0,0,1,0,0,1,0],
[0,0,0,0,0,0,0,0,0]]
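A minimal single-channel sketch that reproduces the two iterations above. It assumes the two example rules are stored in one array indexed by rule number. `expand_once` is a hypothetical helper name, not taken from the question.

```python
import numpy as np

# The two example rules from the question, indexed by rule number.
rules = np.array([
    [[0, 0, 0], [0, 1, 0], [0, 0, 0]],  # rule 0
    [[1, 1, 1], [0, 0, 0], [1, 1, 1]],  # rule 1
])

def expand_once(grid):
    """Replace every cell of `grid` with the 3x3 rule selected by its value."""
    h, w = grid.shape
    out = np.zeros((3 * h, 3 * w), dtype=grid.dtype)
    for y in range(h):
        for x in range(w):
            out[3 * y:3 * y + 3, 3 * x:3 * x + 3] = rules[grid[y, x]]
    return out

grid = np.array([[0]])  # the seed
for _ in range(2):      # two iterations
    grid = expand_once(grid)
print(grid)  # prints the 9x9 grid shown above
```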

I now have a working implementation, but it is the slowest function in my script. Is there a more Pythonic and more efficient way to write it?

import numpy as np

dimensions = 3      # number of channels (last axis of the result)
rule_size_sqrt = 3  # side length of each square rule grid

def decode(rules,fractal_iterations,seed):
    final_seed_matrix = np.zeros((dimensions,3**fractal_iterations,3**fractal_iterations))

    for i in range(dimensions):
        seed_matrix = np.array([[seed]])
        for j in range(fractal_iterations):
            size_y = seed_matrix.shape[0]
            size_x = seed_matrix.shape[1]
            new_matrix = np.zeros((size_y*rule_size_sqrt,size_x*rule_size_sqrt))

            # Replace every cell with the 3x3 rule its value selects for channel i.
            for y in range(size_y):
                for x in range(size_x):
                    seed_value = seed_matrix[y,x]
                    new_matrix[y*rule_size_sqrt : y*rule_size_sqrt+rule_size_sqrt, x*rule_size_sqrt : x*rule_size_sqrt+rule_size_sqrt] = rules[int(seed_value),i]

            seed_matrix = new_matrix
        final_seed_matrix[i] = seed_matrix

    # Move the channel axis last: shape (3**n, 3**n, dimensions).
    return np.moveaxis(final_seed_matrix,0,-1)
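To see why the nested Python loops dominate the run time, one can count the slice assignments they perform: iteration `j` (counting from 0) visits all 9**j cells of a 3**j x 3**j grid, and this is repeated for each of the `dimensions` channels. A small sketch of that count (`count_assignments` is a hypothetical helper, not part of the question's code):

```python
def count_assignments(fractal_iterations, dimensions=3):
    """Number of 3x3 slice assignments the loop-based decode performs."""
    # Iteration j (0-based) visits every cell of a 3**j x 3**j grid, i.e. 9**j cells.
    per_channel = sum(9 ** j for j in range(fractal_iterations))
    return dimensions * per_channel

print(count_assignments(5))  # 22143
```

For `fractal_iterations = 5` that is 22,143 separate Python-level slice assignments per call, so the per-assignment interpreter overhead, rather than the amount of data, sets the cost.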

Solution

  • Here is an optimized version that uses advanced indexing to select and patch together all rules in a single indexing step. This creates a 4D array holding the appropriate rule at the position of the pixel it replaces. Flattening that to 2D is then a matter of swapping the two middle axes and reshaping. It gives the same result as yours on the random test case below and is more than 20 times faster in this benchmark (only tested for integer rules so far):

    results equal: True
    OP       : 24.883304461836815 ms
    optimized: 1.093490980565548 ms
    

    Code:

    import numpy as np
    
    dimensions = 3
    rule_size_sqrt = 3
    
    def decode(rules,fractal_iterations,seed):
        final_seed_matrix = np.zeros((3,3**fractal_iterations,3**fractal_iterations))
    
        for i in range(dimensions):
            seed_matrix = np.array([[seed]])
            for j in range(fractal_iterations):
                size_y = seed_matrix.shape[0]
                size_x = seed_matrix.shape[1]
                new_matrix = np.zeros((size_y*rule_size_sqrt,size_x*rule_size_sqrt))
    
                for y in range(size_y):
                    for x in range(size_x):
                        seed_value = seed_matrix[y,x]
                        new_matrix[y*rule_size_sqrt : y*rule_size_sqrt+rule_size_sqrt, x*rule_size_sqrt : x*rule_size_sqrt+rule_size_sqrt] = rules[int(seed_value),i]
    
                seed_matrix = new_matrix
            final_seed_matrix[i] = seed_matrix
    
        return np.moveaxis(final_seed_matrix,0,-1)
    
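    def decode_fast(rules, fractal_iterations, seed):
        # Intermediate grids are used as indices, so they must be integers.
        rules_int = rules.astype(int)
        seed = np.array([[seed]], dtype=int)
        res = np.empty((3**fractal_iterations, 3**fractal_iterations, dimensions),
                       dtype=rules.dtype)
        for i in range(dimensions):
            grow = seed
            for j in range(1, fractal_iterations):
                # rules_int[grow, i] has shape (3**(j-1), 3**(j-1), 3, 3): one rule per cell.
                # Swapping axes 1 and 2 puts each rule row next to its block row,
                # so the reshape lays the blocks out as a 3**j x 3**j grid.
                grow = rules_int[grow, i].swapaxes(1, 2).reshape(3**j, -1)
            # The last step uses the original rules so the output keeps their dtype.
            grow = rules[grow, i].swapaxes(1, 2).reshape(3**fractal_iterations, -1)
            res[..., i] = grow
        return res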
    
    # 4 rules per channel; values 0..3 are all valid rule indices.
    rules = np.random.randint(0, 4, (4, dimensions, 3, 3))
    seed = 1
    fractal_iterations = 5
    print('results equal:', np.all(decode(rules, fractal_iterations, seed) == decode_fast(rules, fractal_iterations, seed)))
    
    from timeit import repeat
    # Each repeat times 50 calls; best total seconds * 1000 / 50 = ms per call, hence * 20.
    print('OP       :', min(repeat('decode(rules, fractal_iterations, seed)', globals=globals(), number=50))*20, 'ms')
    print('optimized:', min(repeat('decode_fast(rules, fractal_iterations, seed)', globals=globals(), number=50))*20, 'ms')
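The core step described in the answer (index all rules at once, swap the middle axes, reshape) can be checked in isolation. Below is a minimal single-channel sketch, assuming rules stored as an `(n_rules, 3, 3)` array; `step_loop` and `step_indexed` are hypothetical helper names, and a non-square grid is used to show the layout logic does not depend on squareness.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rules = 4
rules = rng.integers(0, n_rules, size=(n_rules, 3, 3))  # single-channel rules
grid = rng.integers(0, n_rules, size=(2, 5))             # any 2D grid of rule numbers

def step_loop(grid):
    """One substitution step with explicit loops, as in the question."""
    h, w = grid.shape
    out = np.empty((3 * h, 3 * w), dtype=rules.dtype)
    for y in range(h):
        for x in range(w):
            out[3 * y:3 * y + 3, 3 * x:3 * x + 3] = rules[grid[y, x]]
    return out

def step_indexed(grid):
    """The same step via advanced indexing, swapaxes and reshape."""
    h, w = grid.shape
    blocks = rules[grid]            # (h, w, 3, 3): blocks[y, x] is the rule for cell (y, x)
    blocks = blocks.swapaxes(1, 2)  # (h, 3, w, 3): element [y, r, x, c]
    # Row-major reshape maps [y, r, x, c] to output row 3*y + r, column 3*x + c.
    return blocks.reshape(3 * h, 3 * w)

print(np.array_equal(step_loop(grid), step_indexed(grid)))  # True
```

Without the `swapaxes`, the reshape would read each 3x3 rule row by row into consecutive output positions and scramble the blocks, which is why the middle axes must be swapped first.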