Search code examples
iteratorjuliagenerator

Reducing memory allocation of a generator in Julia


I am trying to reduce the memory allocation of an inner loop in my code. Below the part that is not working as expected.

using Random 
using StatsBase
using BenchmarkTools
using Distributions

a_dist = Distributions.DiscreteUniform(1, 99)
v_dist = Distributions.DiscreteUniform(1, 2)
population_size = 10000
population = [rand(a_dist, population_size) rand(v_dist, population_size)]


find_all_it3(f::Function, A) = (p[2] for p in eachrow(A) if f(p[1]))

@btime begin 
    c_pool = find_all_it3(x -> (x < 5), population)
    c_pool_dict = countmap(c_pool, alg=:dict)
end


@btime begin
    c_pool_indexes = findall(x -> (x < 5) ,  view(population, :, 1))
    c_pool_dict = countmap(population[c_pool_indexes, 2], alg=:dict)
end

I was hoping that the generator (find_all_it3) would not need to allocate much memory. however as per the btime output it seems that there is an allocation for each loop.

  98.040 μs (10006 allocations: 625.64 KiB)
  18.894 μs (18 allocations: 11.95 KiB)

Now in my scenario the speed and allocation of the findall eventually become an issue, hence I was trying to find a better alternative through generator/iterators so that less allocation occur; is there a way to do that? Are there options to consider?


Solution

  • I don't have an explaination for it but here are the results of a few tests I made

    • The best time is achieved with view(population, :, 1) .< 5 (test4)
    • using broadcast! reduces allocations a bit (test5)
    • the best way to reduce allocation is to do your own loop (test6)
    using BenchmarkTools
    using StatsBase
    
    population_size = 10000
    population = [rand(1:99, population_size) rand(1:2, population_size)]
    
    find_all_it(f::Function, A) = (p[2] for p in eachrow(A) if f(p[1]))
    
    function test1(population)
        c_pool = find_all_it(x -> x < 5, population)
        c_pool_dict = countmap(c_pool, alg=:dict)
    end
    
    function test3(population)
        c_pool_indexes = findall(x -> x < 5,  view(population, :, 1))
        c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
    end
    
    function test4(population)
        c_pool_indexes = view(population, :, 1) .< 5
        c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
    end
    
    function test5(c_pool_indexes, population)
        broadcast!(<, c_pool_indexes, view(population, :, 1), 5)
        c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
    end
    
    function test6(population)
        d = Dict{Int,Int}()
        for i in eachindex(view(population, :, 1))
            if population[i, 1] < 5
                d[population[i,2]] = 1 + get(d,population[i,2],0)
            end
        end
        return d
    end
    
    julia> @btime test1(population);
      68.200 μs (10004 allocations: 625.59 KiB)
    
    julia> @btime test3(population);
      14.800 μs (14 allocations: 9.00 KiB)
    
    julia> @btime test4(population);
      7.250 μs (8 allocations: 9.33 KiB)
    
    julia> temp = zeros(Bool, population_size);
    
    julia> @btime test5(temp, population);
      16.599 μs (5 allocations: 3.78 KiB)
    
    julia> @btime test6(population);
      11.299 μs (4 allocations: 608 bytes)